40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
73 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75 for (
bool b : denseElts.getValues<
bool>())
78 else if (!b && val <= 0)
92 auto shape = m.getType().getShape();
95 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96 if (maskIdx < dimSize)
109 auto maskOperands = m.getOperands();
110 for (
Value operand : maskOperands) {
111 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
126 builder.
create<vector::YieldOp>(loc);
// NOTE(review): fragment — the enclosing function signature is outside the
// visible range; by the case structure this dispatches on a vector
// CombiningKind and validates it against an element type.
132 switch (combiningKind) {
// ADD/MUL: presumably valid for both integer and float element types — the
// shared return for this group is not visible here; confirm upstream.
133 case CombiningKind::ADD:
134 case CombiningKind::MUL:
// Integer-only combining kinds (min/max/bitwise). Their common return line
// (presumably `isa<IntegerType>(elementType)`) falls in a gap of this view.
137 case CombiningKind::MINSI:
138 case CombiningKind::MAXUI:
139 case CombiningKind::MAXSI:
140 case CombiningKind::AND:
141 case CombiningKind::OR:
142 case CombiningKind::XOR:
// Floating-point-only combining kinds share a single return below.
144 case CombiningKind::MINNUMF:
145 case CombiningKind::MAXNUMF:
146 case CombiningKind::MINIMUMF:
147 case CombiningKind::MAXIMUMF:
148 return llvm::isa<FloatType>(elementType);
154 VectorType vectorType) {
155 int64_t elementVectorRank = 0;
156 VectorType elementVectorType =
157 llvm::dyn_cast<VectorType>(shapedType.getElementType());
158 if (elementVectorType)
159 elementVectorRank += elementVectorType.getRank();
162 if (shapedType.getRank() == 0 &&
168 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169 shapedType.getContext());
176 vector::TransferReadOp read) {
177 auto readMask = read.getMask();
178 auto writeMask = write.getMask();
184 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185 if (!couldBeSameSplat)
190 m_Constant<DenseElementsAttr>(&splatAttr)) ||
// NOTE(review): fragment — the signature head (presumably taking a
// `vector::TransferWriteOp defWrite` as the first parameter) and the tail of
// the mask disjunction are outside the visible lines.
202 vector::TransferReadOp read) {
// A read observes exactly the value of the prior write iff the write is
// fully in-bounds and both ops address the same elements: identical
// indices, vector type, and permutation map.
203 return !defWrite.hasOutOfBoundsDim() &&
204 defWrite.getIndices() == read.getIndices() &&
205 defWrite.getVectorType() == read.getVectorType() &&
206 defWrite.getPermutationMap() == read.getPermutationMap() &&
// Masks must both be absent, or (presumably, in the non-visible remainder
// of this disjunction) compare equal — TODO confirm against the full file.
207 ((!defWrite.getMask() && !read.getMask()) ||
// NOTE(review): fragment — the signature head (presumably taking a
// `vector::TransferWriteOp write` as the first parameter) and the closing
// brace are outside the visible lines.
212 vector::TransferWriteOp priorWrite) {
// Two writes fully overwrite the same elements iff they agree on indices,
// mask, vector type, and permutation map. Note that, unlike the RAW check
// above, no out-of-bounds condition appears in the visible conjunction.
213 return priorWrite.getIndices() == write.getIndices() &&
214 priorWrite.getMask() == write.getMask() &&
215 priorWrite.getVectorType() == write.getVectorType() &&
216 priorWrite.getPermutationMap() == write.getPermutationMap();
220 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221 bool testDynamicValueUsingBounds) {
223 if (transferA.getVectorType() != transferB.getVectorType())
225 unsigned rankOffset = transferA.getLeadingShapedRank();
226 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227 Value indexA = transferA.getIndices()[i];
228 Value indexB = transferB.getIndices()[i];
232 if (i < rankOffset) {
235 if (cstIndexA.has_value() && cstIndexB.has_value()) {
236 if (*cstIndexA != *cstIndexB)
240 if (testDynamicValueUsingBounds) {
243 FailureOr<uint64_t> delta =
245 if (succeeded(delta) && *delta != 0)
248 FailureOr<bool> testEqual =
250 if (succeeded(testEqual) && !testEqual.value())
256 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257 if (cstIndexA.has_value() && cstIndexB.has_value()) {
258 int64_t distance =
std::abs(*cstIndexA - *cstIndexB);
259 if (distance >= vectorDim)
263 if (testDynamicValueUsingBounds) {
266 FailureOr<int64_t> delta =
268 if (succeeded(delta) &&
std::abs(*delta) >= vectorDim)
271 FailureOr<int64_t> computeDelta =
273 if (succeeded(computeDelta)) {
274 if (
std::abs(computeDelta.value()) >= vectorDim)
284 VectorTransferOpInterface transferB,
285 bool testDynamicValueUsingBounds) {
286 if (transferA.getSource() != transferB.getSource())
289 testDynamicValueUsingBounds);
299 for (
auto [posInDim, dimSize, offsetInDim] :
300 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
302 if (posInDim < dimSize + offsetInDim)
306 posInDim = offsetInDim;
316 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
318 assert(constOp &&
"Unexpected non-constant index");
319 return constOp.value();
329 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
330 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
331 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
341 llvm::transform(foldResults, std::back_inserter(values),
343 if (
auto attr = foldResult.dyn_cast<
Attribute>())
346 loc, cast<IntegerAttr>(attr).getInt())
349 return cast<Value>(foldResult);
360 auto lhs = mul.getLhs();
361 auto rhs = mul.getRhs();
362 if (lhs.getDefiningOp<vector::VectorScaleOp>())
364 if (rhs.getDefiningOp<vector::VectorScaleOp>())
412 void VectorDialect::initialize() {
414 #define GET_ATTRDEF_LIST
415 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
420 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
423 addInterfaces<VectorInlinerInterface>();
425 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
426 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
428 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
430 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
431 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
432 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
440 if (isa<ub::PoisonAttrInterface>(value))
443 return arith::ConstantOp::materialize(builder, value, type, loc);
459 void vector::MultiDimReductionOp::build(
OpBuilder &builder,
462 CombiningKind
kind) {
466 reductionDims.push_back(en.index());
467 build(builder, result,
kind, source, acc, reductionDims);
470 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
472 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
477 std::optional<SmallVector<int64_t, 4>>
478 MultiDimReductionOp::getShapeForUnroll() {
479 return llvm::to_vector<4>(getSourceVectorType().
getShape());
485 Type inferredReturnType;
486 auto sourceScalableDims = getSourceVectorType().getScalableDims();
487 for (
auto [dimIdx, dimSize] :
489 if (!llvm::any_of(getReductionDims(),
490 [dimIdx = dimIdx](int64_t reductionDimIdx) {
491 return reductionDimIdx ==
static_cast<int64_t
>(dimIdx);
493 targetShape.push_back(dimSize);
494 scalableDims.push_back(sourceScalableDims[dimIdx]);
497 if (targetShape.empty())
498 inferredReturnType = getSourceVectorType().getElementType();
501 targetShape, getSourceVectorType().
getElementType(), scalableDims);
502 if (
getType() != inferredReturnType)
503 return emitOpError() <<
"destination type " <<
getType()
504 <<
" is incompatible with source type "
505 << getSourceVectorType();
511 Type MultiDimReductionOp::getExpectedMaskType() {
512 auto vecType = getSourceVectorType();
515 vecType.getScalableDims());
524 struct ElideUnitDimsInMultiDimReduction
528 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
531 for (
const auto &dim :
enumerate(shape)) {
532 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
540 if (reductionOp.isMasked()) {
542 rootOp = reductionOp.getMaskingOp();
543 mask = reductionOp.getMaskingOp().getMask();
545 rootOp = reductionOp;
548 Location loc = reductionOp.getLoc();
549 Value acc = reductionOp.getAcc();
551 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
553 VectorType newMaskType =
555 dstVecType.getScalableDims());
556 mask = rewriter.
create<vector::ShapeCastOp>(loc, newMaskType, mask);
558 cast = rewriter.
create<vector::ShapeCastOp>(
559 loc, reductionOp.getDestType(), reductionOp.getSource());
564 mask = rewriter.
create<vector::ExtractOp>(loc, mask);
565 cast = rewriter.
create<vector::ExtractOp>(loc, reductionOp.getSource());
570 cast,
nullptr, mask);
577 void MultiDimReductionOp::getCanonicalizationPatterns(
579 results.
add<ElideUnitDimsInMultiDimReduction>(context);
588 arith::FastMathFlags fastMathFlags) {
589 build(builder, result,
kind, vector,
Value(), fastMathFlags);
594 arith::FastMathFlags fastMathFlags) {
595 build(builder, result,
596 llvm::cast<VectorType>(vector.
getType()).getElementType(),
kind, vector,
602 int64_t rank = getSourceVectorType().getRank();
604 return emitOpError(
"unsupported reduction rank: ") << rank;
607 Type eltType = getDest().getType();
609 return emitOpError(
"unsupported reduction type '")
610 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
619 Type ReductionOp::getExpectedMaskType() {
620 auto vecType = getSourceVectorType();
623 vecType.getScalableDims());
630 case arith::AtomicRMWKind::addf:
631 case arith::AtomicRMWKind::addi:
632 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
633 CombiningKind::ADD, vector);
634 case arith::AtomicRMWKind::mulf:
635 case arith::AtomicRMWKind::muli:
636 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
637 CombiningKind::MUL, vector);
638 case arith::AtomicRMWKind::minimumf:
639 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
640 CombiningKind::MINIMUMF, vector);
641 case arith::AtomicRMWKind::mins:
642 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
643 CombiningKind::MINSI, vector);
644 case arith::AtomicRMWKind::minu:
645 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
647 case arith::AtomicRMWKind::maximumf:
648 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
649 CombiningKind::MAXIMUMF, vector);
650 case arith::AtomicRMWKind::maxs:
651 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
652 CombiningKind::MAXSI, vector);
653 case arith::AtomicRMWKind::maxu:
654 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
655 CombiningKind::MAXUI, vector);
656 case arith::AtomicRMWKind::andi:
657 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
658 CombiningKind::AND, vector);
659 case arith::AtomicRMWKind::ori:
660 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
661 CombiningKind::OR, vector);
670 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
671 return llvm::to_vector<4>(getSourceVectorType().
getShape());
678 LogicalResult matchAndRewrite(ReductionOp reductionOp,
683 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
686 if (maskableOp.isMasked()) {
688 rootOp = maskableOp.getMaskingOp();
689 mask = maskableOp.getMaskingOp().getMask();
691 rootOp = reductionOp;
694 auto vectorType = reductionOp.getSourceVectorType();
695 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
698 Location loc = reductionOp.getLoc();
700 mask = rewriter.
create<ExtractOp>(loc, mask);
701 Value result = rewriter.
create<ExtractOp>(loc, reductionOp.getVector());
703 if (
Value acc = reductionOp.getAcc())
706 reductionOp.getFastmathAttr(), mask);
716 results.
add<ElideSingleElementReduction>(context);
730 getIndexingMapsAttrName(result.
name),
734 getIteratorTypesAttrName(result.
name),
737 return IteratorTypeAttr::get(builder.getContext(), t);
743 ArrayAttr indexingMaps,
744 ArrayAttr iteratorTypes) {
745 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
746 ContractionOp::getDefaultKind());
751 ArrayAttr indexingMaps,
752 ArrayAttr iteratorTypes, CombiningKind
kind) {
769 DictionaryAttr dictAttr;
784 dictAttr.getValue().end());
790 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
792 if (!iteratorTypes) {
794 <<
"expected " << getIteratorTypesAttrName(result.
name)
795 <<
" array attribute";
800 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
801 auto maybeIteratorType = symbolizeIteratorType(s);
802 if (!maybeIteratorType.has_value())
803 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
805 iteratorTypeAttrs.push_back(
813 getKindAttrName(result.
name),
815 ContractionOp::getDefaultKind()));
817 if (masksInfo.empty())
819 if (masksInfo.size() != 2)
821 "expected zero or exactly 2 vector mask operands");
822 auto lhsType = llvm::cast<VectorType>(types[0]);
823 auto rhsType = llvm::cast<VectorType>(types[1]);
825 std::array<VectorType, 2> maskTypes = {
835 auto attrNames = getTraitAttrNames();
837 traitAttrsSet.insert_range(attrNames);
839 for (
auto attr : (*this)->getAttrs()) {
840 if (attr.getName() == getIteratorTypesAttrName()) {
842 llvm::cast<ArrayAttr>(attr.getValue())
843 .getAsValueRange<IteratorTypeAttr, IteratorType>();
849 llvm::map_range(iteratorTypes, [&](IteratorType t) ->
Attribute {
853 attrs.emplace_back(getIteratorTypesAttrName(),
855 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
856 attrs.push_back(attr);
860 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
861 p << getRhs() <<
", " << getAcc();
864 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
869 const std::vector<std::pair<int64_t, int64_t>> &map) {
870 for (
auto &dimPair : map) {
871 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
872 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
873 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
880 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
882 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
883 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
886 for (
auto &dimPair : contractingDimMap) {
887 lhsContractingDimSet.insert(dimPair.first);
888 rhsContractingDimSet.insert(dimPair.second);
891 llvm::make_second_range(batchDimMap));
895 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
896 if (lhsContractingDimSet.count(i) > 0)
898 expectedResultDims.push_back(lhsType.getDimSize(i));
902 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
903 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
905 expectedResultDims.push_back(rhsType.getDimSize(i));
909 if (expectedResultDims.empty()) {
911 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
912 return op.emitOpError(
"invalid accumulator/result vector shape");
915 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
916 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
917 if (!resVectorType || !accVectorType)
918 return op.emitOpError(
"invalid accumulator/result vector shape");
924 AffineMap lhsMap = op.getIndexingMapsArray()[0];
925 AffineMap rhsMap = op.getIndexingMapsArray()[1];
927 return op.emitOpError(
928 "expected all dimensions to be either a LHS or a RHS dimension");
931 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
932 VectorType v = pair.first;
933 auto map = pair.second;
934 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
935 unsigned pos = map.getDimPosition(idx);
940 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
941 return op.emitOpError(
"expected all dimensions to get an extent as "
942 "either a LHS or a RHS dimension");
944 AffineMap resMap = op.getIndexingMapsArray()[2];
950 llvm::IsaPred<AffineConstantExpr>) &&
951 "expected constant extent along all dimensions.");
953 auto expectedShape = llvm::to_vector<4>(
955 return cast<AffineConstantExpr>(e).getValue();
959 resVectorType.getScalableDims());
960 if (resVectorType != expected || accVectorType != expected)
961 return op.emitOpError(
962 "invalid accumulator/result vector shape, expected: ")
969 VectorType lhsType = getLhsType();
970 VectorType rhsType = getRhsType();
971 Type accType = getAccType();
972 Type resType = getResultType();
974 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
975 if (!lhsType.getElementType().isSignlessInteger())
976 return emitOpError(
"only supports signless integer types");
980 if (getIndexingMapsArray().size() != 3)
981 return emitOpError(
"expected an indexing map for each vector operand");
986 unsigned numIterators = getIteratorTypes().getValue().size();
988 auto index = it.index();
989 auto map = it.value();
990 if (map.getNumSymbols() != 0)
991 return emitOpError(
"expected indexing map ")
992 << index <<
" to have no symbols";
993 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).
getType());
994 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
997 if (map.getNumDims() != numIterators)
998 return emitOpError(
"expected indexing map ")
999 << index <<
" to have " << numIterators <<
" number of inputs";
1000 if (map.getNumResults() != rank)
1001 return emitOpError(
"expected indexing map ")
1002 << index <<
" to have " << rank <<
" number of outputs";
1003 if (!map.isProjectedPermutation())
1004 return emitOpError(
"expected indexing map ")
1005 << index <<
" to be a projected permutation of its inputs";
1008 auto contractingDimMap = getContractingDimMap();
1009 auto batchDimMap = getBatchDimMap();
1012 if (contractingDimMap.empty())
1013 return emitOpError(
"expected at least one contracting dimension pair");
1016 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1017 return emitOpError(
"invalid contracting dimension map");
1021 return emitOpError(
"invalid batch dimension map");
1025 contractingDimMap, batchDimMap)))
1029 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1030 auto elementType = vectorType ? vectorType.getElementType() : resType;
1032 return emitOpError(
"unsupported contraction type");
1041 Type ContractionOp::getExpectedMaskType() {
1042 auto indexingMaps = this->getIndexingMapsArray();
1045 VectorType lhsType = this->getLhsType();
1046 VectorType rhsType = this->getRhsType();
1048 unsigned numVecDims = lhsIdxMap.
getNumDims();
1057 lhsType.getScalableDims()[dimIdx];
1062 rhsType.getScalableDims()[dimIdx];
1065 assert(!ShapedType::isDynamicShape(maskShape) &&
1066 "Mask shape couldn't be computed");
1070 maskShapeScalableDims);
1075 getIteratorTypesAttrName(), getKindAttrName()};
1085 static std::vector<std::pair<int64_t, int64_t>>
1087 IteratorType targetIteratorType,
MLIRContext *context) {
1088 std::vector<std::pair<int64_t, int64_t>> dimMap;
1090 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1091 if (iteratorType != targetIteratorType)
1097 if (lhsDim >= 0 && rhsDim >= 0)
1098 dimMap.emplace_back(lhsDim, rhsDim);
1103 void ContractionOp::getIterationBounds(
1105 auto lhsShape = getLhsType().getShape();
1106 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1112 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1113 if (iteratorType == IteratorType::reduction) {
1115 int64_t lhsDimIndex =
getResultIndex(indexingMaps[0], targetExpr);
1116 assert(lhsDimIndex >= 0);
1117 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1121 int64_t resDimIndex =
getResultIndex(indexingMaps[2], targetExpr);
1122 assert(resDimIndex >= 0);
1123 assert(resVectorType !=
nullptr);
1124 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1128 void ContractionOp::getIterationIndexMap(
1130 unsigned numMaps = getIndexingMapsArray().size();
1131 iterationIndexMap.resize(numMaps);
1133 auto index = it.index();
1134 auto map = it.value();
1135 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1136 auto dim = cast<AffineDimExpr>(map.getResult(i));
1137 iterationIndexMap[index][dim.getPosition()] = i;
1142 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1144 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1148 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1150 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1154 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1156 getIterationBounds(shape);
1178 template <
typename AddOpType>
1184 auto canonicalize = [&](
Value maybeContraction,
1185 Value otherOperand) -> vector::ContractionOp {
1186 vector::ContractionOp contractionOp =
1187 dyn_cast_or_null<vector::ContractionOp>(
1190 return vector::ContractionOp();
1191 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1192 contractionOp.getAcc().getDefiningOp())) {
1193 if (maybeZero.getValue() ==
1194 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1196 bvm.
map(contractionOp.getAcc(), otherOperand);
1197 auto newContraction =
1198 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1199 rewriter.
replaceOp(addOp, newContraction.getResult());
1200 return newContraction;
1203 return vector::ContractionOp();
1206 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1207 vector::ContractionOp
contract = canonicalize(a, b);
1209 return contract ? success() : failure();
1225 setResultRanges(getResult(), argRanges.front());
1231 result.
addTypes(llvm::cast<VectorType>(source.
getType()).getElementType());
1235 VectorType vectorType = getSourceVectorType();
1236 if (vectorType.getRank() == 0) {
1238 return emitOpError(
"expected position to be empty with 0-D vector");
1241 if (vectorType.getRank() != 1)
1242 return emitOpError(
"unexpected >1 vector rank");
1244 return emitOpError(
"expected position for 1-D vector");
1248 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1250 if (!adaptor.getPosition())
1254 if (
auto splat = getVector().getDefiningOp<vector::SplatOp>())
1255 return splat.getInput();
1258 if (
auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1262 auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1263 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1267 auto srcElements = src.getValues<
Attribute>();
1269 uint64_t posIdx = pos.getInt();
1270 if (posIdx >= srcElements.size())
1273 return srcElements[posIdx];
1280 return index == poisonValue || (index >= 0 && index < maxIndex);
1289 setResultRanges(getResult(), argRanges.front());
1294 auto vectorTy = cast<VectorType>(source.
getType());
1299 Value source, int64_t position) {
1319 build(builder, result, source, dynamicPos,
1324 ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1325 ExtractOp::Adaptor adaptor,
1327 auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1328 if (
static_cast<int64_t
>(adaptor.getStaticPosition().size()) ==
1329 vectorType.getRank()) {
1330 inferredReturnTypes.push_back(vectorType.getElementType());
1332 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1333 vectorType.getRank());
1335 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1336 vectorType.getScalableDims().drop_front(n)));
1344 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1345 return vectorType && vectorType.getShape().equals({1}) &&
1346 vectorType.getElementType() == r.front();
1348 if (l.size() == 1 && r.size() == 1 &&
1349 (isCompatible(l, r) || isCompatible(r, l)))
1356 auto dynamicMarkersCount =
1357 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1358 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1360 "mismatch between dynamic and static positions (kDynamic marker but no "
1361 "corresponding dynamic position) -- this can only happen due to an "
1362 "incorrect fold/rewrite");
1363 auto position = getMixedPosition();
1364 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1366 "expected position attribute of rank no greater than vector rank");
1368 if (
auto attr = dyn_cast<Attribute>(pos)) {
1369 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1371 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1372 return emitOpError(
"expected position attribute #")
1374 <<
" to be a non-negative integer smaller than the "
1375 "corresponding vector dimension or poison (-1)";
1382 template <
typename IntType>
1384 return llvm::to_vector<4>(llvm::map_range(
1385 arrayAttr.getAsRange<IntegerAttr>(),
1386 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1392 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1396 if (extractOp.hasDynamicPosition())
1400 ExtractOp currentOp = extractOp;
1402 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1403 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1406 if (currentOp.hasDynamicPosition())
1409 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1411 extractOp.setOperand(0, currentOp.getVector());
1414 std::reverse(globalPosition.begin(), globalPosition.end());
1415 extractOp.setStaticPosition(globalPosition);
1427 class ExtractFromInsertTransposeChainState {
1429 ExtractFromInsertTransposeChainState(ExtractOp e);
/// Returns true iff the elements of `a` form a prefix of `b`: `a` is no
/// longer than `b` and its first a.size() elements compare equal to the
/// corresponding elements of `b`. (Closing brace lies outside this view.)
1438 template <
typename ContainerA,
typename ContainerB>
1439 bool isContainedWithin(
const ContainerA &a,
const ContainerB &b) {
// std::equal over the first a.size() positions; the size guard makes the
// iterator range `a.begin() + a.size()` valid against `b`.
1440 return a.size() <= b.size() &&
1441 std::equal(a.begin(), a.begin() + a.size(), b.begin());
/// NOTE(review): fragment — only the loop head and the negative-entry guard
/// are visible. Presumably this tests whether `a` and `b` agree at every
/// position where both entries are non-negative, with negative entries
/// acting as wildcards (skipped by the guard below) — confirm against the
/// full file.
1448 template <
typename ContainerA,
typename ContainerB>
1449 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &b) {
// Walk both containers in lockstep; zip stops at the shorter of the two.
1450 for (
auto [elemA, elemB] : llvm::zip(a, b)) {
// Skip positions where either side is negative (sentinel/wildcard value).
1451 if (elemA < 0 || elemB < 0)
1466 void updateStateForNextIteration(
Value v) {
1473 LogicalResult handleTransposeOp();
1476 LogicalResult handleInsertOpWithMatchingPos(
Value &res);
1491 LogicalResult handleInsertOpWithPrefixPos(
Value &res);
1496 Value tryToFoldExtractOpInPlace(
Value source);
1498 ExtractOp extractOp;
1500 int64_t extractedRank;
1502 InsertOp nextInsertOp;
1503 TransposeOp nextTransposeOp;
1518 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1520 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1521 extractedRank(extractOp.getNumIndices()) {
1522 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1523 sentinels.reserve(vectorRank - extractedRank);
1524 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1525 sentinels.push_back(-(i + 1));
1527 extractOp.getStaticPosition().end());
1533 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1535 if (extractOp.hasDynamicPosition())
1538 if (!nextTransposeOp)
1541 nextTransposeOp.getPermutation(), extractOp.getContext()));
1548 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1551 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1558 res = nextInsertOp.getSource();
1560 return success(canFold());
1567 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(
Value &res) {
1569 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1582 res = nextInsertOp.getSource();
1590 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1593 if (extractOp.hasDynamicPosition())
1597 bool nothingToFold = (source == extractOp.getVector());
1598 if (nothingToFold || !canFold())
1603 extractOp.setStaticPosition(
1605 extractOp.getVectorMutable().assign(source);
1606 return extractOp.getResult();
1610 Value ExtractFromInsertTransposeChainState::fold() {
1612 if (extractOp.hasDynamicPosition())
1615 Value valueToExtractFrom = extractOp.getVector();
1616 updateStateForNextIteration(valueToExtractFrom);
1617 while (nextInsertOp || nextTransposeOp) {
1620 if (succeeded(handleTransposeOp())) {
1621 valueToExtractFrom = nextTransposeOp.getVector();
1622 updateStateForNextIteration(valueToExtractFrom);
1628 if (succeeded(handleInsertOpWithMatchingPos(result)))
1633 if (succeeded(handleInsertOpWithPrefixPos(result)))
1634 return tryToFoldExtractOpInPlace(result);
1644 valueToExtractFrom = nextInsertOp.getDest();
1645 updateStateForNextIteration(valueToExtractFrom);
1648 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1653 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1654 auto vecType = dyn_cast<VectorType>(type);
1655 return vecType && vecType.getRank() == 0;
1664 Operation *defOp = extractOp.getVector().getDefiningOp();
1665 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1669 if (extractOp.getType() == source.
getType())
1671 auto getRank = [](
Type type) {
1672 return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1677 unsigned broadcastSrcRank = getRank(source.
getType());
1678 if (broadcastSrcRank == 0 && source.
getType() == extractOp.getType())
1681 unsigned extractResultRank = getRank(extractOp.getType());
1682 if (extractResultRank > broadcastSrcRank)
1685 auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1686 auto broadcastVecType = llvm::dyn_cast<VectorType>(source.
getType());
1687 if (extractVecType && broadcastVecType &&
1688 extractVecType.getShape() !=
1689 broadcastVecType.getShape().take_back(extractResultRank))
1692 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1693 int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1699 broadcastOp.computeBroadcastedUnitDims();
1702 int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1703 for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1704 if (broadcastedUnitDims.contains(i))
1708 int64_t rankDiff = broadcastSrcRank - extractResultRank;
1709 extractPos.erase(extractPos.begin(),
1710 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1713 extractOp->setOperands(
1714 llvm::to_vector(llvm::concat<Value>(
ValueRange(source), dynPos)));
1715 extractOp.setStaticPosition(staticPos);
1716 return extractOp.getResult();
1732 if (extractOp.hasDynamicPosition())
1735 auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1740 if (shuffleOp.getResultVectorType().getRank() != 1)
1743 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1744 auto shuffleMask = shuffleOp.getMask();
1745 int64_t extractIdx = extractOp.getStaticPosition()[0];
1746 int64_t shuffleIdx = shuffleMask[extractIdx];
1749 if (shuffleIdx < inputVecSize) {
1750 extractOp.setOperand(0, shuffleOp.getV1());
1751 extractOp.setStaticPosition({shuffleIdx});
1753 extractOp.setOperand(0, shuffleOp.getV2());
1754 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1757 return extractOp.getResult();
1763 if (extractOp.hasDynamicPosition())
1766 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1771 auto getDimReverse = [](VectorType type, int64_t n) {
1772 return type.getShape().take_back(n + 1).front();
1774 int64_t destinationRank =
1775 llvm::isa<VectorType>(extractOp.getType())
1776 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1778 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1780 if (destinationRank > 0) {
1781 auto destinationType =
1782 llvm::cast<VectorType>(extractOp.getResult().getType());
1783 for (int64_t i = 0; i < destinationRank; i++) {
1787 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1788 getDimReverse(destinationType, i))
1795 std::reverse(extractedPos.begin(), extractedPos.end());
1798 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1799 strides.push_back(stride);
1801 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1804 int64_t position =
linearize(extractedPos, strides);
1808 int64_t numDimension =
1809 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1811 for (int64_t i = 0; i < numDimension; i++) {
1812 newStrides.push_back(stride);
1814 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1816 std::reverse(newStrides.begin(), newStrides.end());
1820 extractOp.setStaticPosition(newPosition);
1821 extractOp.setOperand(0, shapeCastOp.getSource());
1822 return extractOp.getResult();
1828 if (extractOp.hasDynamicPosition())
1831 auto extractStridedSliceOp =
1832 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1833 if (!extractStridedSliceOp)
1842 if (extractStridedSliceOp.hasNonUnitStrides())
1847 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1848 while (!sliceOffsets.empty()) {
1849 size_t lastOffset = sliceOffsets.size() - 1;
1850 if (sliceOffsets.back() != 0 ||
1851 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1852 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1854 sliceOffsets.pop_back();
1856 unsigned destinationRank = 0;
1857 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1858 destinationRank = vecType.getRank();
1861 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1862 sliceOffsets.size())
1866 assert(extractedPos.size() >= sliceOffsets.size());
1867 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1868 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1869 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1873 extractOp.setStaticPosition(extractedPos);
1874 return extractOp.getResult();
1880 if (extractOp.hasDynamicPosition())
1883 int64_t destinationRank =
1884 llvm::isa<VectorType>(extractOp.getType())
1885 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1887 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1897 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1898 insertOp.getSourceVectorType().getRank();
1899 if (destinationRank > insertOp.getSourceVectorType().getRank())
1901 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1904 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1905 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1908 bool disjoint =
false;
1910 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1911 int64_t start = insertOffsets[dim];
1913 (dim < insertRankDiff)
1915 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1916 int64_t end = start + size;
1917 int64_t offset = extractOffsets[dim];
1919 if (start <= offset && offset < end) {
1920 if (dim >= insertRankDiff)
1921 offsetDiffs.push_back(offset - start);
1931 int64_t srcRankDiff =
1932 insertOp.getSourceVectorType().getRank() - destinationRank;
1933 for (int64_t i = 0; i < destinationRank; i++) {
1934 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1935 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1939 extractOp.getVectorMutable().assign(insertOp.getSource());
1942 extractOp.setStaticPosition(offsetDiffs);
1943 return extractOp.getResult();
1947 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1960 if (extractOp.hasDynamicPosition())
1964 auto fromElementsOp = extractOp.getVector().
getDefiningOp<FromElementsOp>();
1965 if (!fromElementsOp)
1969 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1970 if (vecType.isScalable())
1974 int64_t rank = vecType.getRank();
1976 if (extractOp.getType() != vecType.getElementType())
1978 assert(
static_cast<int64_t
>(indices.size()) == rank &&
1979 "unexpected number of indices");
1984 for (
int i = rank - 1; i >= 0; --i) {
1985 flatIndex += indices[i] * stride;
1986 stride *= vecType.getDimSize(i);
1988 return fromElementsOp.getElements()[flatIndex];
1993 template <
typename OpType,
typename AdaptorType>
1996 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
1997 OperandRange dynamicPosition = op.getDynamicPosition();
2001 if (!dynamicPosition.size())
2008 bool opChange =
false;
2009 for (
unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2010 if (!ShapedType::isDynamic(staticPosition[i]))
2012 Attribute positionAttr = dynamicPositionAttr[index];
2013 Value position = dynamicPosition[index++];
2014 if (
auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2015 staticPosition[i] = attr.getInt();
2019 operands.push_back(position);
2023 op.setStaticPosition(staticPosition);
2024 op.getOperation()->setOperands(operands);
2025 return op.getResult();
2034 int64_t poisonVal) {
2035 if (!is_contained(staticPos, poisonVal))
2043 if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2052 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2057 if (denseAttr.isSplat()) {
2059 if (
auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2064 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2065 if (vecTy.isScalable())
2068 if (extractOp.hasDynamicPosition()) {
2083 copy(extractOp.getStaticPosition(), completePositions.begin());
2086 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2089 if (
auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2091 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2094 newAttr = *denseValuesBegin;
2104 if (getNumIndices() == 0 && getVector().
getType() == getResult().
getType())
2107 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2115 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2144 Operation *defOp = extractOp.getVector().getDefiningOp();
2145 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2149 if (extractOp.getType() == source.
getType())
2151 auto getRank = [](
Type type) {
2152 return llvm::isa<VectorType>(type)
2153 ? llvm::cast<VectorType>(type).getRank()
2156 unsigned broadcastSrcRank = getRank(source.
getType());
2157 unsigned extractResultRank = getRank(extractOp.getType());
2161 if (extractResultRank < broadcastSrcRank)
2165 if (extractResultRank == 0)
2169 extractOp, extractOp.getType(), source);
2182 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2186 VectorType extractedMaskType =
2187 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2189 if (!extractedMaskType)
2192 auto maskOperands = createMaskOp.getOperands();
2194 VectorType maskType = createMaskOp.getVectorType();
2196 bool containsUnknownDims =
false;
2199 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2201 int64_t pos = extractOpPos[dimIdx];
2202 Value operand = maskOperands[dimIdx];
2203 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2206 containsUnknownDims =
true;
2210 int64_t createMaskBound =
2211 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2213 if (pos != ShapedType::kDynamic) {
2216 allFalse |= pos >= createMaskBound;
2217 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2221 containsUnknownDims =
true;
2228 }
else if (!containsUnknownDims) {
2230 extractOp, extractedMaskType,
2231 maskOperands.drop_front(extractOpPos.size()));
2241 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2243 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2247 VectorType sourceType = castOp.getSourceVectorType();
2248 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2252 if (sourceType.getNumElements() != targetType.getNumElements())
2256 castOp.getSource());
2266 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2269 if (extractOp.hasDynamicPosition())
2273 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2278 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2279 if (!fromElementsOp)
2281 VectorType inputType = fromElementsOp.getType();
2284 if (resultType.isScalable() || inputType.isScalable())
2290 llvm::to_vector(extractOp.getStaticPosition());
2291 firstElementPos.append(resultType.getRank(), 0);
2294 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2295 flatIndex += firstElementPos[i] * stride;
2296 stride *= inputType.getDimSize(i);
2301 extractOp, resultType,
2302 fromElementsOp.getElements().slice(flatIndex,
2303 resultType.getNumElements()));
2311 results.
add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2312 results.
add(foldExtractFromShapeCastToShapeCast);
2313 results.
add(foldExtractFromFromElements);
2318 for (
auto attr : arrayAttr)
2319 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2326 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2341 if (!llvm::all_equal(fromElementsOp.getElements()))
2344 fromElementsOp.getElements().front());
2359 setResultRanges(getResult(), argRanges.front());
2367 int64_t rankDiff = dstShape.size() - srcShape.size();
2368 int64_t dstDim = rankDiff;
2370 for (
auto [s1, s2] :
2371 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2373 assert(s1 == 1 &&
"expected \"dim-1\" broadcasting");
2383 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2402 Value BroadcastOp::createOrFoldBroadcastOp(
2405 assert(!dstShape.empty() &&
"unexpected empty dst shape");
2409 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
2410 if (broadcastedDims.contains(i))
2412 checkShape.push_back(dstShape[i]);
2414 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2415 "ill-formed broadcastedDims contains values not confined to "
2420 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
2424 if (!srcVectorType) {
2425 assert(checkShape.empty() &&
2426 "ill-formed createOrFoldBroadcastOp arguments");
2427 return b.
createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2430 assert(srcVectorType.getShape().equals(checkShape) &&
2431 "ill-formed createOrFoldBroadcastOp arguments");
2442 broadcastShape.reserve(dstShape.size());
2458 int64_t nextSrcShapeDim = broadcastedDims.size();
2459 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2460 if (broadcastedDims.contains(i)) {
2465 broadcastShape.push_back(dstShape[i]);
2466 permutation[i] = broadcastShape.size() - 1;
2472 permutation[i] = nextSrcShapeDim++;
2476 llvm::append_range(broadcastShape, srcVectorType.getShape());
2481 "unexpected \"dim-1\" broadcast");
2483 VectorType broadcastType =
VectorType::get(broadcastShape, elementType);
2485 vector::BroadcastableToResult::Success &&
2486 "must be broadcastable");
2490 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2491 if (permutation[i] != i)
2492 return b.
createOrFold<vector::TransposeOp>(loc, res, permutation);
2498 Type srcType, VectorType dstVectorType,
2499 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2503 return BroadcastableToResult::Success;
2505 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2507 return BroadcastableToResult::SourceTypeNotAVector;
2509 int64_t srcRank = srcVectorType.getRank();
2510 int64_t dstRank = dstVectorType.getRank();
2511 if (srcRank > dstRank)
2512 return BroadcastableToResult::SourceRankHigher;
2515 int64_t lead = dstRank - srcRank;
2516 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2519 bool foundMismatchingDims =
false;
2522 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2523 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2524 if (srcDim != 1 && srcDim != dstDim)
2525 foundMismatchingDims =
true;
2528 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2529 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2530 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2533 (srcDimScalableFlag != dstDimScalableFlag &&
2534 (srcDim != 1 || srcDimScalableFlag)))
2535 foundMismatchingDims =
true;
2537 if (foundMismatchingDims) {
2538 if (mismatchingDims !=
nullptr) {
2539 mismatchingDims->first.dim = srcDim;
2540 mismatchingDims->first.isScalable = srcDimScalableFlag;
2542 mismatchingDims->second.dim = dstDim;
2543 mismatchingDims->second.isScalable = dstDimScalableFlag;
2545 return BroadcastableToResult::DimensionMismatch;
2549 return BroadcastableToResult::Success;
2553 std::pair<VectorDim, VectorDim> mismatchingDims;
2555 getSourceType(), getResultVectorType(), &mismatchingDims);
2556 if (res == BroadcastableToResult::Success)
2558 if (res == BroadcastableToResult::SourceRankHigher)
2559 return emitOpError(
"source rank higher than destination rank");
2560 if (res == BroadcastableToResult::DimensionMismatch) {
2561 return emitOpError(
"dimension mismatch (")
2562 << (mismatchingDims.first.isScalable ?
"[" :
"")
2563 << mismatchingDims.first.dim
2564 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
2565 << (mismatchingDims.second.isScalable ?
"[" :
"")
2566 << mismatchingDims.second.dim
2567 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
2569 if (res == BroadcastableToResult::SourceTypeNotAVector)
2570 return emitOpError(
"source type is not a vector");
2571 llvm_unreachable(
"unexpected vector.broadcast op error");
2575 if (getSourceType() == getResultVectorType())
2577 if (!adaptor.getSource())
2579 auto vectorType = getResultVectorType();
2580 if (
auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2581 if (vectorType.getElementType() != attr.getType())
2585 if (
auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2586 if (vectorType.getElementType() != attr.getType())
2590 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2603 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2607 broadcastOp.getResultVectorType(),
2608 srcBroadcast.getSource());
2618 results.
add<BroadcastFolder>(context);
2626 VectorType resultType = getResultVectorType();
2627 VectorType v1Type = getV1VectorType();
2628 VectorType v2Type = getV2VectorType();
2630 int64_t resRank = resultType.getRank();
2631 int64_t v1Rank = v1Type.getRank();
2632 int64_t v2Rank = v2Type.getRank();
2633 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2634 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2635 if (!wellFormed0DCase && !wellFormedNDCase)
2636 return emitOpError(
"rank mismatch");
2639 for (int64_t r = 1; r < v1Rank; ++r) {
2640 int64_t resDim = resultType.getDimSize(r);
2641 int64_t v1Dim = v1Type.getDimSize(r);
2642 int64_t v2Dim = v2Type.getDimSize(r);
2643 if (resDim != v1Dim || v1Dim != v2Dim)
2644 return emitOpError(
"dimension mismatch");
2648 int64_t maskLength = mask.size();
2649 if (maskLength <= 0)
2650 return emitOpError(
"invalid mask length");
2651 if (maskLength != resultType.getDimSize(0))
2652 return emitOpError(
"mask length mismatch");
2654 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2655 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2658 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
2664 ShuffleOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
2665 ShuffleOp::Adaptor adaptor,
2667 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2668 auto v1Rank = v1Type.getRank();
2672 shape.reserve(v1Rank);
2673 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2676 llvm::append_range(shape, v1Type.getShape().drop_front());
2677 inferredReturnTypes.push_back(
2682 template <
typename T>
2685 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2686 return value == expected++;
2690 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2691 auto v1Type = getV1VectorType();
2692 auto v2Type = getV2VectorType();
2694 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2695 "Vector shuffle does not support scalable vectors");
2699 if (v1Type.getRank() == 0)
2703 auto mask = getMask();
2710 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2711 if (!v1Attr || !v2Attr)
2715 bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2716 bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2717 if (isV1Poison && isV2Poison)
2722 if (v1Type.getRank() != 1)
2732 to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2733 poisonElement = v2Elements[0];
2737 to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2738 poisonElement = v1Elements[0];
2742 int64_t v1Size = v1Type.getDimSize(0);
2743 for (int64_t maskIdx : mask) {
2746 if (maskIdx == ShuffleOp::kPoisonIndex) {
2747 indexedElm = poisonElement;
2749 if (maskIdx < v1Size)
2750 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2752 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
2755 results.push_back(indexedElm);
2770 VectorType v1VectorType = shuffleOp.getV1VectorType();
2772 if (v1VectorType.getRank() > 0)
2774 if (mask.size() != 1)
2794 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2795 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2797 if (!v1Splat || !v2Splat)
2800 if (v1Splat.getInput() != v2Splat.getInput())
2816 VectorType resultType = op.getResultVectorType();
2817 if (resultType.isScalable())
2819 op,
"ShuffleOp can't represent a scalable interleave");
2821 if (resultType.getRank() != 1)
2823 op,
"ShuffleOp can't represent an n-D interleave");
2825 VectorType sourceType = op.getV1VectorType();
2826 if (sourceType != op.getV2VectorType() ||
2827 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2829 op,
"ShuffleOp types don't match an interleave");
2833 int64_t resultVectorSize = resultType.getNumElements();
2834 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2835 int64_t maskValueA = shuffleMask[i * 2];
2836 int64_t maskValueB = shuffleMask[(i * 2) + 1];
2837 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2839 "ShuffleOp mask not interleaving");
2851 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2861 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2866 build(builder, result, source, dest, {});
2870 auto dstVectorType = getDestVectorType();
2871 if (dstVectorType.getRank() == 0) {
2873 return emitOpError(
"expected position to be empty with 0-D vector");
2876 if (dstVectorType.getRank() != 1)
2877 return emitOpError(
"unexpected >1 vector rank");
2879 return emitOpError(
"expected position for 1-D vector");
2883 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2885 if (!adaptor.getPosition())
2888 auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2889 auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2890 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2891 if (!src || !dst || !pos)
2897 auto dstElements = dst.getValues<
Attribute>();
2901 uint64_t posIdx = pos.getInt();
2902 if (posIdx >= results.size())
2904 results[posIdx] = src;
2915 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2920 auto vectorTy = cast<VectorType>(dest.
getType());
2921 build(builder, result, source, dest,
2926 Value source,
Value dest, int64_t position) {
2939 posVals.reserve(position.size());
2940 llvm::transform(position, std::back_inserter(posVals),
2942 build(builder, result, source, dest, posVals);
2951 build(builder, result, source, dest, dynamicPos,
2957 auto destVectorType = getDestVectorType();
2958 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
2960 "expected position attribute of rank no greater than dest vector rank");
2961 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2962 if (srcVectorType &&
2963 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2964 static_cast<unsigned>(destVectorType.getRank())))
2965 return emitOpError(
"expected position attribute rank + source rank to "
2966 "match dest vector rank");
2967 if (!srcVectorType &&
2968 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
2970 "expected position attribute rank to match the dest vector rank");
2972 if (
auto attr = pos.dyn_cast<
Attribute>()) {
2973 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2975 destVectorType.getDimSize(idx))) {
2976 return emitOpError(
"expected position attribute #")
2978 <<
" to be a non-negative integer smaller than the "
2980 "dest vector dimension";
2997 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2998 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2999 srcVecType.getNumElements())
3002 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
3014 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
3015 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3017 if (!srcSplat || !dstSplat)
3020 if (srcSplat.getInput() != dstSplat.getInput())
3033 int64_t maxVectorSizeFoldThreshold) {
3034 if (insertOp.hasDynamicPosition())
3037 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3045 VectorType destTy = insertOp.getDestVectorType();
3046 if (destTy.isScalable())
3050 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3051 !insertOp->hasOneUse())
3057 copy(insertOp.getStaticPosition(), completePositions.begin());
3058 int64_t insertBeginPosition =
3062 Type destEltType = destTy.getElementType();
3067 if (
auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3068 if (intAttr.getType() != expectedType)
3077 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3078 for (
auto value : denseSource.getValues<
Attribute>())
3084 auto allValues = llvm::to_vector(denseDst.getValues<
Attribute>());
3085 copy(insertedValues, allValues.begin() + insertBeginPosition);
3093 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
3096 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3099 constexpr int64_t vectorSizeFoldThreshold = 256;
3103 if (getNumIndices() == 0 && getSourceType() ==
getType())
3109 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3113 vectorSizeFoldThreshold)) {
3139 template <
typename OpType>
3141 ArrayAttr arrayAttr,
3143 StringRef attrName) {
3144 if (arrayAttr.size() > shape.size())
3145 return op.emitOpError(
"expected ")
3146 << attrName <<
" attribute of rank no greater than vector rank";
3153 template <
typename OpType>
3154 static LogicalResult
3156 int64_t
max, StringRef attrName,
3157 bool halfOpen =
true) {
3158 for (
auto attr : arrayAttr) {
3159 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3163 if (val < min || val >= upper)
3164 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
3165 <<
min <<
", " << upper <<
")";
3173 template <
typename OpType>
3174 static LogicalResult
3177 bool halfOpen =
true, int64_t
min = 0) {
3178 for (
auto [index, attrDimPair] :
3180 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3181 int64_t
max = std::get<1>(attrDimPair);
3184 if (val < min || val >=
max)
3185 return op.emitOpError(
"expected ")
3186 << attrName <<
" dimension " << index <<
" to be confined to ["
3187 <<
min <<
", " <<
max <<
")";
3197 template <
typename OpType>
3199 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3201 bool halfOpen =
true, int64_t
min = 1) {
3202 assert(arrayAttr1.size() <= shape.size());
3203 assert(arrayAttr2.size() <= shape.size());
3204 for (
auto [index, it] :
3206 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3207 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3208 int64_t
max = std::get<2>(it);
3211 if (val1 + val2 < 0 || val1 + val2 >=
max)
3212 return op.emitOpError(
"expected sum(")
3213 << attrName1 <<
", " << attrName2 <<
") dimension " << index
3214 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3221 auto attrs = llvm::map_range(values, [context](int64_t v) ->
Attribute {
3228 auto sourceVectorType = getSourceVectorType();
3229 auto destVectorType = getDestVectorType();
3230 auto offsets = getOffsetsAttr();
3231 auto strides = getStridesAttr();
3232 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3234 "expected offsets of same size as destination vector rank");
3235 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3236 return emitOpError(
"expected strides of same size as source vector rank");
3237 if (sourceVectorType.getRank() > destVectorType.getRank())
3239 "expected source rank to be no greater than destination rank");
3241 auto sourceShape = sourceVectorType.getShape();
3242 auto destShape = destVectorType.getShape();
3244 destShape.size() - sourceShape.size(), 0);
3245 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3246 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3247 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3256 offName,
"source vector shape",
3260 unsigned rankDiff = destShape.size() - sourceShape.size();
3261 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3262 if (sourceVectorType.getScalableDims()[idx] !=
3263 destVectorType.getScalableDims()[idx + rankDiff]) {
3264 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3267 if (sourceVectorType.getScalableDims()[idx]) {
3268 auto sourceSize = sourceShape[idx];
3269 auto destSize = destShape[idx + rankDiff];
3270 if (sourceSize != destSize) {
3271 return emitOpError(
"expected size at idx=")
3273 << (
" to match the corresponding base size from the input "
3275 << sourceSize << (
" vs ") << destSize << (
")");
3286 class FoldInsertStridedSliceSplat final
3291 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3294 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3296 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3298 if (!srcSplatOp || !destSplatOp)
3301 if (srcSplatOp.getInput() != destSplatOp.getInput())
3304 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3311 class FoldInsertStridedSliceOfExtract final
3316 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3318 auto extractStridedSliceOp =
3319 insertStridedSliceOp.getSource()
3320 .getDefiningOp<vector::ExtractStridedSliceOp>();
3322 if (!extractStridedSliceOp)
3325 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3329 if (extractStridedSliceOp.getStrides() !=
3330 insertStridedSliceOp.getStrides() ||
3331 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3334 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3341 class InsertStridedSliceConstantFolder final
3348 static constexpr int64_t vectorSizeFoldThreshold = 256;
3359 VectorType destTy = destVector.getType();
3360 if (destTy.isScalable())
3364 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3365 !destVector.hasOneUse())
3374 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3378 if (op.hasNonUnitStrides())
3381 VectorType sliceVecTy = sourceValue.getType();
3383 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3393 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3394 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3395 auto sliceValuesIt = denseSlice.value_begin<
Attribute>();
3396 auto newValues = llvm::to_vector(denseDest.getValues<
Attribute>());
3399 currDestPosition.begin() + rankDifference, currDestPosition.end());
3403 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
3404 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
3405 assert(sliceValuesIt != denseSlice.value_end<
Attribute>() &&
3406 "Invalid slice element");
3407 newValues[linearizedPosition] = *sliceValuesIt;
3420 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3422 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3423 InsertStridedSliceConstantFolder>(context);
3426 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3427 if (getSourceVectorType() == getDestVectorType())
3444 p <<
" " << getLhs() <<
", " << getRhs();
3446 p <<
", " << getAcc();
3449 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
3460 if (operandsInfo.size() < 2)
3462 "expected at least 2 operands");
3463 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3464 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3467 "expected vector type for operand #1");
3472 vRHS.getScalableDims()[0]};
3474 vLHS.getElementType(), scalableDimsRes);
3478 resType =
VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3484 OuterProductOp::getKindAttrName(result.
name),
3486 OuterProductOp::getDefaultKind()));
3492 (operandsInfo.size() > 2 &&
3498 Type tRHS = getOperandTypeRHS();
3499 VectorType vLHS = getOperandVectorTypeLHS(),
3500 vRHS = llvm::dyn_cast<VectorType>(tRHS),
3501 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3503 if (vLHS.getRank() != 1)
3504 return emitOpError(
"expected 1-d vector for operand #1");
3508 if (vRHS.getRank() != 1)
3509 return emitOpError(
"expected 1-d vector for operand #2");
3510 if (vRES.getRank() != 2)
3511 return emitOpError(
"expected 2-d vector result");
3512 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3513 return emitOpError(
"expected #1 operand dim to match result dim #1");
3514 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3515 return emitOpError(
"expected #2 operand dim to match result dim #2");
3516 if (vLHS.isScalable() && !vRHS.isScalable()) {
3520 "expected either both or only #2 operand dim to be scalable");
3524 if (vRES.getRank() != 1)
3525 return emitOpError(
"expected 1-d vector result");
3526 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3527 return emitOpError(
"expected #1 operand dim to match result dim #1");
3530 if (vACC && vACC != vRES)
3531 return emitOpError(
"expected operand #3 of same type as result type");
3535 return emitOpError(
"unsupported outerproduct type");
3544 Type OuterProductOp::getExpectedMaskType() {
3545 auto vecType = this->getResultVectorType();
3548 vecType.getScalableDims());
3560 ArrayAttr offsets, ArrayAttr sizes,
3561 ArrayAttr strides) {
3562 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3564 shape.reserve(vectorType.getRank());
3566 for (
unsigned e = offsets.size(); idx < e; ++idx)
3567 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3568 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3569 shape.push_back(vectorType.getShape()[idx]);
3572 vectorType.getScalableDims());
3585 offsetsAttr, sizesAttr, stridesAttr));
3586 result.
addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.
name),
3590 result.
addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.
name),
3595 auto type = getSourceVectorType();
3596 auto offsets = getOffsetsAttr();
3597 auto sizes = getSizesAttr();
3598 auto strides = getStridesAttr();
3599 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3601 "expected offsets, sizes and strides attributes of same size");
3603 auto shape = type.getShape();
3604 auto offName = getOffsetsAttrName();
3605 auto sizesName = getSizesAttrName();
3606 auto stridesName = getStridesAttrName();
3622 shape, offName, sizesName,
3627 offsets, sizes, strides);
3628 if (getResult().
getType() != resultType)
3629 return emitOpError(
"expected result type to be ") << resultType;
3631 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
3632 if (type.getScalableDims()[idx]) {
3633 auto inputDim = type.getShape()[idx];
3634 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3635 if (inputDim != inputSize)
3636 return emitOpError(
"expected size at idx=")
3638 << (
" to match the corresponding base size from the input "
3640 << inputSize << (
" vs ") << inputDim << (
")");
3650 static LogicalResult
3653 auto getElement = [](ArrayAttr array,
int idx) {
3654 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3656 ArrayAttr extractOffsets = op.getOffsets();
3658 ArrayAttr extractSizes = op.getSizes();
3659 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3661 if (op.getSourceVectorType().getRank() !=
3662 insertOp.getSourceVectorType().getRank())
3664 ArrayAttr insertOffsets = insertOp.getOffsets();
3665 ArrayAttr insertStrides = insertOp.getStrides();
3668 if (extractOffsets.size() > insertOffsets.size())
3670 bool patialoverlap =
false;
3671 bool disjoint =
false;
3673 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3674 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
3676 int64_t start = getElement(insertOffsets, dim);
3677 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3678 int64_t offset = getElement(extractOffsets, dim);
3679 int64_t size = getElement(extractSizes, dim);
3681 if (start <= offset && offset < end) {
3684 if (offset + size > end)
3685 patialoverlap =
true;
3686 offsetDiffs.push_back(offset - start);
3693 if (!disjoint && !patialoverlap) {
3694 op.setOperand(insertOp.getSource());
3703 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3713 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3714 if (getSourceVectorType() == getResult().
getType())
3729 class StridedSliceConstantMaskFolder final
3734 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3738 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3739 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3740 if (!constantMaskOp)
3743 if (extractStridedSliceOp.hasNonUnitStrides())
3756 sliceMaskDimSizes.reserve(maskDimSizes.size());
3757 for (
auto [maskDimSize, sliceOffset, sliceSize] :
3758 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3759 int64_t sliceMaskDimSize =
std::max(
3760 static_cast<int64_t
>(0),
3761 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3762 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3765 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3766 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3767 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3770 if (llvm::is_contained(sliceMaskDimSizes, 0))
3771 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3776 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3783 class StridedSliceSplatConstantFolder final
3788 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3792 Value sourceVector = extractStridedSliceOp.getVector();
3797 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3811 class StridedSliceNonSplatConstantFolder final
3816 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3820 Value sourceVector = extractStridedSliceOp.getVector();
3826 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3827 if (!dense || dense.isSplat())
3831 if (extractStridedSliceOp.hasNonUnitStrides())
3834 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.
getType());
3838 VectorType sliceVecTy = extractStridedSliceOp.getType();
3840 int64_t sliceRank = sliceVecTy.getRank();
3852 auto denseValuesBegin = dense.value_begin<
Attribute>();
3854 sliceValues.reserve(sliceVecTy.getNumElements());
3857 int64_t linearizedPosition =
linearize(currSlicePosition, sourceStrides);
3858 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3860 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3864 assert(
static_cast<int64_t
>(sliceValues.size()) ==
3865 sliceVecTy.getNumElements() &&
3866 "Invalid number of slice elements");
3876 class StridedSliceBroadcast final
3888 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3889 auto dstVecType = llvm::cast<VectorType>(op.getType());
3890 unsigned dstRank = dstVecType.getRank();
3891 unsigned rankDiff = dstRank - srcRank;
3895 bool lowerDimMatch =
true;
3896 for (
unsigned i = 0; i < srcRank; i++) {
3897 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3898 lowerDimMatch =
false;
3907 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3908 if (!lowerDimMatch && !isScalarSrc) {
3909 source = rewriter.
create<ExtractStridedSliceOp>(
3910 op->getLoc(), source,
3921 class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
3927 auto splat = op.getVector().getDefiningOp<SplatOp>();
3951 class ContiguousExtractStridedSliceToExtract final
3958 if (op.hasNonUnitStrides())
3960 Value source = op.getOperand();
3961 auto sourceType = cast<VectorType>(source.
getType());
3962 if (sourceType.isScalable() || sourceType.getRank() == 0)
3971 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3972 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
3979 if (numOffsets == 0)
3984 if (numOffsets == sourceType.getRank() &&
3985 static_cast<int>(sizes.size()) == sourceType.getRank())
3989 for (
int i = 0; i < numOffsets; ++i) {
3997 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
3998 sizes[numOffsets] == 1) {
4003 auto extractOffsets =
ArrayRef(offsets).take_front(numOffsets);
4004 Value extract = rewriter.
create<vector::ExtractOp>(op->getLoc(), source,
4013 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4017 results.
add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4018 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
4019 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4029 VectorType vectorType,
Value source,
4030 ValueRange indices, AffineMapAttr permutationMapAttr,
4031 ArrayAttr inBoundsAttr) {
4032 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4033 Value padding = builder.
create<arith::ConstantOp>(
4035 build(builder, result, vectorType, source, indices, permutationMapAttr,
4036 padding,
Value(), inBoundsAttr);
4041 VectorType vectorType,
Value source,
4045 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4049 build(builder, result, vectorType, source, indices, permutationMapAttr,
4055 VectorType vectorType,
Value source,
4059 llvm::cast<ShapedType>(source.
getType()), vectorType);
4061 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4065 build(builder, result, vectorType, source, indices, permutationMapAttr,
4067 Value(), inBoundsAttr);
4073 VectorType vectorType,
Value source,
4076 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4077 Value padding = builder.
create<arith::ConstantOp>(
4079 build(builder, result, vectorType, source, indices, padding, inBounds);
4082 template <
typename EmitFun>
4084 EmitFun emitOpError) {
4086 for (
auto expr : permutationMap.
getResults()) {
4087 auto dim = dyn_cast<AffineDimExpr>(expr);
4088 auto zero = dyn_cast<AffineConstantExpr>(expr);
4090 if (zero.getValue() != 0) {
4092 "requires a projected permutation_map (at most one dim or the zero "
4093 "constant can appear in each result)");
4098 return emitOpError(
"requires a projected permutation_map (at most one "
4099 "dim or the zero constant can appear in each result)");
4101 if (seen[dim.getPosition()]) {
4103 "requires a permutation_map that is a permutation (found one dim "
4104 "used more than once)");
4106 seen[dim.getPosition()] =
true;
4111 static LogicalResult
4113 VectorType vectorType, VectorType maskType,
4114 VectorType inferredMaskType,
AffineMap permutationMap,
4115 ArrayAttr inBounds) {
4116 if (op->hasAttr(
"masked")) {
4117 return op->emitOpError(
"masked attribute has been removed. "
4118 "Use in_bounds instead.");
4121 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4122 return op->emitOpError(
4123 "requires source to be a memref or ranked tensor type");
4125 auto elementType = shapedType.getElementType();
4126 DataLayout dataLayout = DataLayout::closest(op);
4127 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4129 unsigned sourceVecSize =
4131 vectorElementType.getShape().back();
4132 unsigned resultVecSize =
4134 vectorType.getShape().back();
4135 if (resultVecSize % sourceVecSize != 0)
4136 return op->emitOpError(
4137 "requires the bitwidth of the minor 1-D vector to be an integral "
4138 "multiple of the bitwidth of the minor 1-D vector of the source");
4140 unsigned sourceVecEltRank = vectorElementType.getRank();
4141 unsigned resultVecRank = vectorType.getRank();
4142 if (sourceVecEltRank > resultVecRank)
4143 return op->emitOpError(
4144 "requires source vector element and vector result ranks to match.");
4145 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4148 return op->emitOpError(
"requires a permutation_map with result dims of "
4149 "the same rank as the vector type");
4152 return op->emitOpError(
"does not support masks with vector element type");
4155 unsigned minorSize =
4156 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4157 unsigned resultVecSize =
4160 return op->emitOpError(
4161 "requires the bitwidth of the minor 1-D vector to be an integral "
4162 "multiple of the bitwidth of the source element type");
4166 return op->emitOpError(
"requires a permutation_map with result dims of "
4167 "the same rank as the vector type");
4171 return op->emitOpError(
"requires permutation_map without symbols");
4173 if (permutationMap.
getNumInputs() != shapedType.getRank())
4174 return op->emitOpError(
"requires a permutation_map with input dims of the "
4175 "same rank as the source type");
4177 if (maskType && maskType != inferredMaskType)
4178 return op->emitOpError(
"inferred mask type (")
4179 << inferredMaskType <<
") and mask operand type (" << maskType
4182 if (permutationMap.
getNumResults() !=
static_cast<int64_t
>(inBounds.size()))
4183 return op->emitOpError(
"expects the in_bounds attr of same rank "
4184 "as permutation_map results: ")
4186 <<
" vs inBounds of size: " << inBounds.size();
4193 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4194 if (op.getPermutationMap().isMinorIdentity())
4195 elidedAttrs.push_back(op.getPermutationMapAttrName());
4197 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
4198 elidedAttrs.push_back(op.getInBoundsAttrName());
4203 p <<
" " << getSource() <<
"[" <<
getIndices() <<
"], " << getPadding();
4205 p <<
", " << getMask();
4214 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
4219 if (maskShape.empty())
4220 maskShape.push_back(1);
4242 if (hasMask.succeeded()) {
4249 if (types.size() != 2)
4250 return parser.
emitError(typesLoc,
"requires two types");
4252 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4253 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4254 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4255 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4257 return parser.
emitError(typesLoc,
"requires vector type");
4258 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.
name);
4265 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4267 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.
name);
4269 if (!inBoundsAttr) {
4279 if (hasMask.succeeded()) {
4280 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4282 maskInfo.
location,
"does not support masks with vector element type");
4285 "expected the same rank for the vector and the "
4286 "results of the permutation map");
4294 result.
addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4296 {1, static_cast<int32_t>(indexInfo.size()), 1,
4297 static_cast<int32_t>(hasMask.succeeded())}));
4303 ShapedType shapedType = getShapedType();
4305 VectorType maskType = getMaskType();
4306 auto paddingType = getPadding().getType();
4307 auto permutationMap = getPermutationMap();
4308 VectorType inferredMaskType =
4311 auto sourceElementType = shapedType.getElementType();
4313 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
4314 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4316 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4317 shapedType, vectorType, maskType,
4318 inferredMaskType, permutationMap, getInBounds())))
4321 if (
auto sourceVectorElementType =
4322 llvm::dyn_cast<VectorType>(sourceElementType)) {
4325 if (sourceVectorElementType != paddingType)
4327 "requires source element type and padding type to match.");
4331 if (!VectorType::isValidElementType(paddingType))
4332 return emitOpError(
"requires valid padding vector elemental type");
4335 if (paddingType != sourceElementType)
4337 "requires formal padding and source of the same elemental type");
4341 [&](Twine t) {
return emitOpError(t); });
4348 Type TransferReadOp::getExpectedMaskType() {
4352 template <
typename TransferOp>
4353 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4356 if (op.getShapedType().isDynamicDim(indicesIdx))
4358 Value index = op.getIndices()[indicesIdx];
4360 if (!cstOp.has_value())
4363 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4364 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4366 return cstOp.value() + vectorSize <= sourceSize;
4369 template <
typename TransferOp>
4373 if (op.getTransferRank() == 0)
4378 newInBounds.reserve(op.getTransferRank());
4383 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
4385 if (op.isDimInBounds(i)) {
4386 newInBounds.push_back(
true);
4391 bool inBounds =
false;
4392 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
4395 dimExpr.getPosition());
4396 nonBcastDims.push_back(i);
4399 newInBounds.push_back(inBounds);
4407 bool allNonBcastDimsInBounds = llvm::all_of(
4408 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
4409 if (allNonBcastDimsInBounds) {
4412 newInBounds[idx] =
true;
4424 template <
typename TransferOp>
4426 auto mask = op.getMask();
4433 op.getMaskMutable().clear();
4447 static Value foldRAW(TransferReadOp readOp) {
4448 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4450 auto defWrite = readOp.getSource().
getDefiningOp<vector::TransferWriteOp>();
4453 return defWrite.getVector();
4455 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4456 cast<VectorTransferOpInterface>(readOp.getOperation())))
4458 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4464 if (
Value vec = foldRAW(*
this))
4478 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4482 void TransferReadOp::getEffects(
4485 if (llvm::isa<MemRefType>(getShapedType()))
4491 if (hasPureTensorSemantics())
4519 struct TransferReadAfterWriteToBroadcast
4525 if (readOp.hasOutOfBoundsDim() ||
4526 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4528 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4533 if (readOp.getTransferChunkAccessed() !=
4534 defWrite.getTransferChunkAccessed())
4541 if (readOp.getIndices() != defWrite.getIndices() ||
4542 readOp.getMask() != defWrite.getMask())
4544 Value vec = defWrite.getVector();
4566 broadcastShape[pos.value()] = destShape[pos.index()];
4567 broadcastScalableFlags[pos.value()] =
4568 readOp.getVectorType().getScalableDims()[pos.index()];
4571 broadcastShape, defWrite.getVectorType().getElementType(),
4572 broadcastScalableFlags);
4573 vec = rewriter.
create<vector::BroadcastOp>(loc, broadcastedType, vec);
4584 results.
add<TransferReadAfterWriteToBroadcast>(context);
4594 AffineMapAttr permutationMapAttr,
4596 ArrayAttr inBoundsAttr) {
4597 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
4598 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4599 mask, inBoundsAttr);
4605 AffineMapAttr permutationMapAttr,
4606 ArrayAttr inBoundsAttr) {
4607 build(builder, result, vector, dest, indices, permutationMapAttr,
4608 Value(), inBoundsAttr);
4619 (inBounds && !inBounds.value().empty())
4622 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
4623 build(builder, result, vector, dest, indices, permutationMapAttr,
4624 Value(), inBoundsAttr);
4632 auto vectorType = llvm::cast<VectorType>(vector.
getType());
4634 llvm::cast<ShapedType>(dest.
getType()), vectorType);
4635 build(builder, result, vector, dest, indices, permutationMap, inBounds);
4651 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
4656 if (types.size() != 2)
4657 return parser.
emitError(typesLoc,
"requires two types");
4659 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4661 return parser.
emitError(typesLoc,
"requires vector type");
4662 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4663 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4664 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4665 auto permMapAttrName =
4666 TransferWriteOp::getPermutationMapAttrName(result.
name);
4673 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4675 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.
name);
4677 if (!inBoundsAttr) {
4686 if (hasMask.succeeded()) {
4687 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4689 maskInfo.
location,
"does not support masks with vector element type");
4692 "expected the same rank for the vector and the "
4693 "results of the permutation map");
4699 result.
addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4701 {1, 1, static_cast<int32_t>(indexInfo.size()),
4702 static_cast<int32_t>(hasMask.succeeded())}));
4703 return failure(llvm::isa<RankedTensorType>(shapedType) &&
4708 p <<
" " << getVector() <<
", " << getSource() <<
"[" <<
getIndices() <<
"]";
4710 p <<
", " << getMask();
4717 ShapedType shapedType = getShapedType();
4719 VectorType maskType = getMaskType();
4720 auto permutationMap = getPermutationMap();
4721 VectorType inferredMaskType =
4725 if (llvm::size(
getIndices()) != shapedType.getRank())
4726 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4730 if (hasBroadcastDim())
4731 return emitOpError(
"should not have broadcast dimensions");
4733 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4734 shapedType, vectorType, maskType,
4735 inferredMaskType, permutationMap, getInBounds())))
4739 [&](Twine t) {
return emitOpError(t); });
4746 Type TransferWriteOp::getExpectedMaskType() {
4767 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4771 if (write.getTransferRank() == 0)
4773 auto rankedTensorType =
4774 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4776 if (!rankedTensorType)
4779 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4783 if (read.getTransferRank() == 0)
4786 if (!read.getPermutationMap().isMinorIdentity() ||
4787 !write.getPermutationMap().isMinorIdentity())
4790 if (read.getTransferRank() != write.getTransferRank())
4793 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4796 if (read.getSource().getType() != rankedTensorType)
4799 if (read.getVectorType() != write.getVectorType())
4802 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4805 auto isNotConstantZero = [](
Value v) {
4807 return !cstOp.has_value() || cstOp.value() != 0;
4809 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4810 llvm::any_of(write.getIndices(), isNotConstantZero))
4813 results.push_back(read.getSource());
4817 static bool checkSameValueWAR(vector::TransferReadOp read,
4818 vector::TransferWriteOp write) {
4819 return read.getSource() == write.getSource() &&
4820 read.getIndices() == write.getIndices() &&
4821 read.getPermutationMap() == write.getPermutationMap() &&
4822 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4839 static LogicalResult foldWAR(TransferWriteOp write,
4841 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4843 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4847 if (!checkSameValueWAR(read, write))
4849 results.push_back(read.getSource());
4853 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4855 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
4857 if (succeeded(foldWAR(*
this, results)))
4866 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4870 void TransferWriteOp::getEffects(
4873 if (llvm::isa<MemRefType>(getShapedType()))
4879 if (hasPureTensorSemantics())
4914 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4916 vector::TransferWriteOp writeToModify = writeOp;
4919 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4923 writeToModify.getSourceMutable().assign(defWrite.getSource());
4928 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4929 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4933 if (!defWrite->hasOneUse())
4935 writeToModify = defWrite;
4936 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4965 struct SwapExtractSliceOfTransferWrite
4972 if (!insertOp.hasUnitStride())
4975 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4976 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4978 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4979 if (!transferOp || !transferOp->hasOneUse())
4984 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4986 "use-def chain is rank-reducing");
4990 if (!extractOp.hasZeroOffset()) {
4992 "ExtractSliceOp has non-zero offset");
4996 if (!llvm::all_of(transferOp.getIndices(), [](
Value value) {
5000 "TranferWriteOp has non-zero offset");
5004 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5006 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
5009 for (
auto [insertSize, extractSize] :
5010 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5013 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
5018 assert(transferOp.getVectorType().hasStaticShape() &&
5019 "expected vector to have a static shape");
5022 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5023 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
5025 insertOp,
"TransferWriteOp may not write the full tensor.");
5031 auto newExtractOp = rewriter.
create<tensor::ExtractSliceOp>(
5032 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
5033 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
5034 insertOp.getMixedStrides());
5035 auto newTransferWriteOp = rewriter.
create<TransferWriteOp>(
5036 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5037 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5040 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5050 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5057 static LogicalResult verifyLoadStoreMemRefLayout(
Operation *op,
5059 MemRefType memRefTy) {
5062 if (!vecTy.isScalable() &&
5063 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5066 if (!memRefTy.isLastDimUnitStride())
5067 return op->
emitOpError(
"most minor memref dim must have unit stride");
5075 if (failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
5079 Type memElemTy = memRefTy.getElementType();
5080 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5081 if (memVecTy != resVecTy)
5082 return emitOpError(
"base memref and result vector types should match");
5083 memElemTy = memVecTy.getElementType();
5086 if (resVecTy.getElementType() != memElemTy)
5087 return emitOpError(
"base and result element types should match");
5088 if (llvm::size(
getIndices()) != memRefTy.getRank())
5089 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5107 if (failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
5111 Type memElemTy = memRefTy.getElementType();
5112 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5113 if (memVecTy != valueVecTy)
5115 "base memref and valueToStore vector types should match");
5116 memElemTy = memVecTy.getElementType();
5119 if (valueVecTy.getElementType() != memElemTy)
5120 return emitOpError(
"base and valueToStore element type should match");
5121 if (llvm::size(
getIndices()) != memRefTy.getRank())
5122 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5126 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5136 VectorType maskVType = getMaskVectorType();
5137 VectorType passVType = getPassThruVectorType();
5141 if (resVType.getElementType() != memType.getElementType())
5142 return emitOpError(
"base and result element type should match");
5143 if (llvm::size(
getIndices()) != memType.getRank())
5144 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5145 if (resVType.getShape() != maskVType.getShape())
5146 return emitOpError(
"expected result shape to match mask shape");
5147 if (resVType != passVType)
5148 return emitOpError(
"expected pass_thru of same type as result type");
5161 load, load.getType(), load.getBase(), load.getIndices());
5164 rewriter.
replaceOp(load, load.getPassThru());
5169 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
5176 results.
add<MaskedLoadFolder>(context);
5190 VectorType maskVType = getMaskVectorType();
5194 if (valueVType.getElementType() != memType.getElementType())
5195 return emitOpError(
"base and valueToStore element type should match");
5196 if (llvm::size(
getIndices()) != memType.getRank())
5197 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5198 if (valueVType.getShape() != maskVType.getShape())
5199 return emitOpError(
"expected valueToStore shape to match mask shape");
5212 store, store.getValueToStore(), store.getBase(), store.getIndices());
5220 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
5227 results.
add<MaskedStoreFolder>(context);
5230 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5240 VectorType indVType = getIndexVectorType();
5241 VectorType maskVType = getMaskVectorType();
5243 ShapedType baseType = getBaseType();
5245 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5246 return emitOpError(
"requires base to be a memref or ranked tensor type");
5248 if (resVType.getElementType() != baseType.getElementType())
5249 return emitOpError(
"base and result element type should match");
5250 if (llvm::size(
getIndices()) != baseType.getRank())
5251 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
5252 if (resVType.getShape() != indVType.getShape())
5253 return emitOpError(
"expected result dim to match indices dim");
5254 if (resVType.getShape() != maskVType.getShape())
5255 return emitOpError(
"expected result dim to match mask dim");
5256 if (resVType != getPassThruVectorType())
5257 return emitOpError(
"expected pass_thru of same type as result type");
5265 Type GatherOp::getExpectedMaskType() {
5266 auto vecType = this->getIndexVectorType();
5269 vecType.getScalableDims());
5272 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5277 static LogicalResult isZeroBasedContiguousSeq(
Value indexVec) {
5278 auto vecType = dyn_cast<VectorType>(indexVec.
getType());
5279 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5290 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5303 rewriter.
replaceOp(gather, gather.getPassThru());
5308 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
5319 if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5323 op.getIndices(), op.getMask(),
5332 results.
add<GatherFolder, FoldContiguousGather>(context);
5340 VectorType indVType = getIndexVectorType();
5341 VectorType maskVType = getMaskVectorType();
5345 if (valueVType.getElementType() != memType.getElementType())
5346 return emitOpError(
"base and valueToStore element type should match");
5347 if (llvm::size(
getIndices()) != memType.getRank())
5348 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5349 if (valueVType.getShape() != indVType.getShape())
5350 return emitOpError(
"expected valueToStore dim to match indices dim");
5351 if (valueVType.getShape() != maskVType.getShape())
5352 return emitOpError(
"expected valueToStore dim to match mask dim");
5371 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
5382 if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5386 op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5394 results.
add<ScatterFolder, FoldContiguousScatter>(context);
5402 VectorType maskVType = getMaskVectorType();
5403 VectorType passVType = getPassThruVectorType();
5407 if (resVType.getElementType() != memType.getElementType())
5408 return emitOpError(
"base and result element type should match");
5409 if (llvm::size(
getIndices()) != memType.getRank())
5410 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5411 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5412 return emitOpError(
"expected result dim to match mask dim");
5413 if (resVType != passVType)
5414 return emitOpError(
"expected pass_thru of same type as result type");
5427 expand, expand.getType(), expand.getBase(), expand.getIndices());
5430 rewriter.
replaceOp(expand, expand.getPassThru());
5435 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
5442 results.
add<ExpandLoadFolder>(context);
5450 VectorType maskVType = getMaskVectorType();
5454 if (valueVType.getElementType() != memType.getElementType())
5455 return emitOpError(
"base and valueToStore element type should match");
5456 if (llvm::size(
getIndices()) != memType.getRank())
5457 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5458 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5459 return emitOpError(
"expected valueToStore dim to match mask dim");
5464 class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
5472 compress, compress.getValueToStore(), compress.getBase(),
5473 compress.getIndices());
5481 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
5488 results.
add<CompressStoreFolder>(context);
5497 setResultRanges(getResult(), argRanges.front());
5503 unsigned rankA = a.size();
5504 unsigned rankB = b.size();
5505 assert(rankA < rankB);
5507 auto isOne = [](int64_t v) {
return v == 1; };
5511 if (rankA == 0 && llvm::all_of(b, isOne))
5516 while (i < rankA &&
j < rankB) {
5517 int64_t dimA = a[i];
5519 while (dimB < dimA &&
j < rankB)
5527 if (i < rankA && llvm::all_of(a.slice(i), isOne))
5529 if (
j < rankB && llvm::all_of(b.slice(
j), isOne))
5533 return i == rankA &&
j == rankB;
5536 static LogicalResult verifyVectorShapeCast(
Operation *op,
5537 VectorType sourceVectorType,
5538 VectorType resultVectorType) {
5540 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5541 return op->
emitOpError(
"source/result vectors must have same element type");
5542 auto sourceShape = sourceVectorType.getShape();
5543 auto resultShape = resultVectorType.getShape();
5546 int64_t sourceDimProduct = std::accumulate(
5547 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5548 int64_t resultDimProduct = std::accumulate(
5549 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5550 if (sourceDimProduct != resultDimProduct)
5551 return op->
emitOpError(
"source/result number of elements must match");
5554 unsigned sourceRank = sourceVectorType.getRank();
5555 unsigned resultRank = resultVectorType.getRank();
5556 if (sourceRank < resultRank) {
5557 if (!isValidShapeCast(sourceShape, resultShape))
5559 }
else if (sourceRank > resultRank) {
5560 if (!isValidShapeCast(resultShape, sourceShape))
5565 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5566 int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5567 if (sourceNScalableDims != resultNScalableDims)
5568 return op->
emitOpError(
"different number of scalable dims at source (")
5569 << sourceNScalableDims <<
") and result (" << resultNScalableDims
5571 sourceVectorType.getNumDynamicDims();
5577 auto sourceVectorType =
5578 llvm::dyn_cast_or_null<VectorType>(getSource().
getType());
5579 auto resultVectorType =
5580 llvm::dyn_cast_or_null<VectorType>(getResult().
getType());
5583 if (sourceVectorType && resultVectorType)
5584 return verifyVectorShapeCast(*
this, sourceVectorType, resultVectorType);
5595 if (
auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5596 if (getResult().
getType() == otherOp.getSource().getType())
5597 return otherOp.getSource();
5600 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5601 VectorType resultType = llvm::cast<VectorType>(getResult().
getType());
5602 if (srcType.getRank() < resultType.getRank()) {
5603 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5605 }
else if (srcType.getRank() > resultType.getRank()) {
5606 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5612 setOperand(otherOp.getSource());
5617 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5618 if (bcastOp.getSourceType() ==
getType())
5619 return bcastOp.getSource();
5627 class ShapeCastConstantFolder final :
public OpRewritePattern<ShapeCastOp> {
5634 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5638 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5654 static VectorType trimTrailingOneDims(VectorType oldType) {
5661 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5662 newShape = newShape.drop_back(1);
5663 newScalableDims = newScalableDims.drop_back(1);
5668 if (newShape.empty()) {
5669 newShape = oldShape.take_back();
5670 newScalableDims = oldScalableDims.take_back();
5673 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5688 class ShapeCastCreateMaskFolderTrailingOneDim final
5695 Value shapeOpSrc = shapeOp->getOperand(0);
5696 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
5697 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
5698 if (!createMaskOp && !constantMaskOp)
5701 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5702 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5704 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5705 if (newVecType != shapeOpResTy)
5708 auto numDimsToDrop =
5709 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5716 auto maskOperands = createMaskOp.getOperands();
5717 auto numMaskOperands = maskOperands.size();
5720 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5722 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5723 if (!constant || (constant.value() != 1))
5727 maskOperands.drop_back(numDimsToDrop);
5734 if (constantMaskOp) {
5735 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5736 auto numMaskOperands = maskDimSizes.size();
5739 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5741 if (maskDimSizes[i] != 1)
5745 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5760 class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
5767 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5772 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5773 broadcastSourceShape = srcType.getShape();
5775 shapeCastOp.getResultVectorType().getShape();
5779 if (broadcastSourceShape ==
5780 shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5782 shapeCastOp, shapeCastOp.getResultVectorType(),
5783 broadcastOp.getSource());
5789 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5790 if (srcType.getNumElements() ==
5791 shapeCastOp.getResultVectorType().getNumElements()) {
5793 shapeCastOp, shapeCastOp.getResultVectorType(),
5794 broadcastOp.getSource());
5807 results.
add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5808 ShapeCastBroadcastFolder>(context);
5816 auto sourceVectorType = getSourceVectorType();
5817 auto resultVectorType = getResultVectorType();
5819 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5820 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5821 return emitOpError(
"dimension size mismatch at: ") << i;
5824 DataLayout dataLayout = DataLayout::closest(*
this);
5825 auto sourceElementBits =
5827 auto resultElementBits =
5830 if (sourceVectorType.getRank() == 0) {
5831 if (sourceElementBits != resultElementBits)
5832 return emitOpError(
"source/result bitwidth of the 0-D vector element "
5833 "types must be equal");
5834 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
5835 resultElementBits * resultVectorType.getShape().back()) {
5837 "source/result bitwidth of the minor 1-D vectors must be equal");
5849 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5850 if (getResult().
getType() == otherOp.getSource().getType())
5851 return otherOp.getSource();
5853 setOperand(otherOp.getSource());
5857 Attribute sourceConstant = adaptor.getSource();
5858 if (!sourceConstant)
5861 Type srcElemType = getSourceVectorType().getElementType();
5862 Type dstElemType = getResultVectorType().getElementType();
5864 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5865 if (floatPack.isSplat()) {
5866 auto splat = floatPack.getSplatValue<FloatAttr>();
5869 if (srcElemType.
isF16() && dstElemType.
isF32()) {
5870 uint32_t bits =
static_cast<uint32_t
>(
5871 splat.getValue().bitcastToAPInt().getZExtValue());
5873 bits = (bits << 16) | (bits & 0xffff);
5874 APInt intBits(32, bits);
5875 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5881 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5882 if (intPack.isSplat()) {
5883 auto splat = intPack.getSplatValue<IntegerAttr>();
5885 if (llvm::isa<IntegerType>(dstElemType)) {
5890 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5891 APInt intBits = splat.getValue().zext(dstBitWidth);
5894 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5895 intBits = (intBits << srcBitWidth) | intBits;
5910 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5913 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5922 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
5923 VectorType vectorType =
5927 memRefType.getMemorySpace()));
5931 MemRefType canonicalType =
getMemRefType().canonicalizeStridedLayout();
5932 if (!canonicalType.getLayout().isIdentity())
5933 return emitOpError(
"expects operand to be a memref with identity layout");
5934 if (!getResultMemRefType().getLayout().isIdentity())
5935 return emitOpError(
"expects result to be a memref with identity layout");
5936 if (getResultMemRefType().getMemorySpace() !=
5938 return emitOpError(
"expects result in same memory space");
5941 auto resultType = getResultMemRefType();
5945 "expects result and operand with same underlying scalar type: ")
5947 if (extractShape(sourceType) != extractShape(resultType))
5949 "expects concatenated result and operand shapes to be equal: ")
5960 VectorType vt = llvm::cast<VectorType>(vector.
getType());
5963 for (
unsigned i = 0; i < permutation.size(); ++i) {
5964 transposedShape[i] = vt.getShape()[permutation[i]];
5965 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5970 transposedScalableDims));
5975 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5978 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5980 return attr.reshape(getResultVectorType());
5988 for (int64_t i = 0, e = perm.size(); i < e; i++) {
5997 VectorType vectorType = getSourceVectorType();
5998 VectorType resultType = getResultVectorType();
5999 int64_t rank = resultType.getRank();
6000 if (vectorType.getRank() != rank)
6001 return emitOpError(
"vector result rank mismatch: ") << rank;
6004 int64_t size = perm.size();
6006 return emitOpError(
"transposition length mismatch: ") << size;
6009 if (ta.value() < 0 || ta.value() >= rank)
6010 return emitOpError(
"transposition index out of range: ") << ta.value();
6011 if (seen[ta.value()])
6012 return emitOpError(
"duplicate position index: ") << ta.value();
6013 seen[ta.value()] =
true;
6014 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6015 return emitOpError(
"dimension size mismatch at: ") << ta.value();
6020 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6021 return llvm::to_vector<4>(getResultVectorType().
getShape());
6027 class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
6037 for (
auto index : permutation2)
6038 result.push_back(permutation1[index]);
6043 vector::TransposeOp parentTransposeOp =
6044 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6045 if (!parentTransposeOp)
6049 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6052 transposeOp, transposeOp.getResult().getType(),
6053 parentTransposeOp.getVector(), permutation);
6059 struct FoldTransposedScalarBroadcast final
6065 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
6069 auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
6070 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
6072 transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
6087 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6092 transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6098 class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
6104 Value transposeSrc = transpOp.getVector();
6105 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
6106 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
6107 if (!createMaskOp && !constantMaskOp)
6115 auto maskOperands = createMaskOp.getOperands();
6120 transpOp, transpOp.getResultVectorType(), newOperands);
6125 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6129 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6136 void vector::TransposeOp::getCanonicalizationPatterns(
6138 results.
add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
6139 TransposeFolder, FoldTransposeSplat>(context);
6148 assert(
kind == ConstantMaskKind::AllTrue ||
6149 kind == ConstantMaskKind::AllFalse);
6150 build(builder, result, type,
6151 kind == ConstantMaskKind::AllTrue
6157 auto resultType = llvm::cast<VectorType>(getResult().
getType());
6159 if (resultType.getRank() == 0) {
6160 if (getMaskDimSizes().size() != 1)
6161 return emitError(
"array attr must have length 1 for 0-D vectors");
6162 auto dim = getMaskDimSizes()[0];
6163 if (dim != 0 && dim != 1)
6164 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
6169 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
6171 "must specify array attr of size equal vector result rank");
6174 auto resultShape = resultType.getShape();
6175 auto resultScalableDims = resultType.getScalableDims();
6177 for (
const auto [index, maskDimSize] :
llvm::enumerate(maskDimSizes)) {
6178 if (maskDimSize < 0 || maskDimSize > resultShape[index])
6180 "array attr of size out of bounds of vector result dimension size");
6181 if (resultScalableDims[index] && maskDimSize != 0 &&
6182 maskDimSize != resultShape[index])
6184 "only supports 'none set' or 'all set' scalable dimensions");
6188 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6189 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
6190 if (anyZeros && !allZeros)
6191 return emitOpError(
"expected all mask dim sizes to be zeros, "
6192 "as a result of conjunction with zero mask dim");
6196 bool ConstantMaskOp::isAllOnesMask() {
6199 if (resultType.getRank() == 0) {
6200 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
6201 return getMaskDimSizes()[0] == 1;
6203 for (
const auto [resultSize, maskDimSize] :
6204 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6205 if (maskDimSize < resultSize)
6220 build(builder, result, type, operands);
6224 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
6226 if (vectorType.getRank() == 0) {
6227 if (getNumOperands() != 1)
6229 "must specify exactly one operand for 0-D create_mask");
6230 }
else if (getNumOperands() !=
6231 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
6233 "must specify an operand for each result vector dimension");
6269 VectorType maskType = createMaskOp.getVectorType();
6271 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6274 constexpr std::array<int64_t, 1> rankZeroShape{1};
6275 constexpr std::array<bool, 1> rankZeroScalableDims{
false};
6276 if (maskType.getRank() == 0) {
6277 maskTypeDimSizes = rankZeroShape;
6278 maskTypeDimScalableFlags = rankZeroScalableDims;
6284 for (
auto [i, dimSize] :
llvm::enumerate(createMaskOp.getOperands())) {
6289 if (maskTypeDimScalableFlags[i] && intSize >= 0)
6291 constantDims.push_back(*intSize);
6295 if (vscaleMultiplier < maskTypeDimSizes[i])
6297 constantDims.push_back(*vscaleMultiplier);
6304 for (
auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6305 value = std::clamp<int64_t>(value, 0, maskDimSize);
6308 if (llvm::is_contained(constantDims, 0))
6309 constantDims.assign(constantDims.size(), 0);
6322 results.
add<CreateMaskFolder>(context);
6333 assert(maskRegionBuilder &&
6334 "builder callback for 'maskRegion' must be present");
6340 maskRegionBuilder(builder, maskableOp);
6347 build(builder, result, resultTypes, mask,
Value(), maskableOp,
6355 build(builder, result, mask, maskableOp, maskRegionBuilder);
6376 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
6383 MaskOp::ensureTerminator(maskRegion, builder, result.
location);
6397 result.
types.append(resultTypes);
6403 if (parsePassthru.succeeded())
6411 p <<
" " << getMask();
6413 p <<
", " << getPassthru();
6417 Block *singleBlock = &getMaskRegion().getBlocks().
front();
6424 p <<
" : " << getMask().getType();
6425 if (getNumResults() > 0)
6426 p <<
" -> " << getResultTypes();
6431 MaskOp>::ensureTerminator(region, builder, loc);
6443 assert(isa<vector::YieldOp>(oldYieldOp) &&
"Expected vector::YieldOp");
6446 if (maskedOp == oldYieldOp)
6449 opBuilder.setInsertionPoint(oldYieldOp);
6450 opBuilder.create<vector::YieldOp>(loc, maskedOp->
getResults());
6452 oldYieldOp->
erase();
6457 Block &block = getMaskRegion().getBlocks().
front();
6459 return emitOpError(
"expects a terminator within the mask region");
6462 if (numMaskRegionOps > 2)
6463 return emitOpError(
"expects only one operation to mask");
6466 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
6468 return emitOpError(
"expects a terminator within the mask region");
6470 if (terminator->getNumOperands() != getNumResults())
6472 "expects number of results to match mask region yielded values");
6475 if (numMaskRegionOps == 1)
6478 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
6480 return emitOpError(
"expects a MaskableOpInterface within the mask region");
6484 return emitOpError(
"expects number of results to match maskable operation "
6485 "number of results");
6487 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
6489 "expects result type to match maskable operation result type");
6492 [](
Type t) { return llvm::isa<VectorType>(t); }) > 1)
6493 return emitOpError(
"multiple vector results not supported");
6496 Type expectedMaskType = maskableOp.getExpectedMaskType();
6497 if (getMask().
getType() != expectedMaskType)
6498 return emitOpError(
"expects a ")
6499 << expectedMaskType <<
" mask for the maskable operation";
6502 Value passthru = getPassthru();
6504 if (!maskableOp.supportsPassthru())
6506 "doesn't expect a passthru argument for this maskable operation");
6509 return emitOpError(
"expects result when passthru argument is provided");
6512 return emitOpError(
"expects passthru type to match result type");
6519 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6529 Operation *maskableOp = getMaskableOp();
6533 llvm::append_range(results, maskableOp->
getResults());
6545 auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6546 if (maskingOp.getMaskableOp())
6549 if (!maskOp.isEmpty())
6552 Block *block = maskOp.getMaskBlock();
6553 auto terminator = cast<vector::YieldOp>(block->
front());
6554 if (terminator.getNumOperands() == 0)
6557 rewriter.
replaceOp(maskOp, terminator.getOperands());
6565 results.
add<ElideEmptyMaskOp>(context);
6572 Block *block = getMaskBlock();
6576 return &block->
front();
6580 bool MaskOp::hasPassthru() {
return getPassthru() !=
Value(); }
6587 VectorType srcType = getSourceType();
6588 VectorType initialType = getInitialValueType();
6590 int64_t srcRank = srcType.getRank();
6591 int64_t reductionDim = getReductionDim();
6592 if (reductionDim >= srcRank)
6593 return emitOpError(
"reduction dimension ")
6594 << reductionDim <<
" has to be less than " << srcRank;
6597 int64_t initialValueRank = initialType.getRank();
6598 if (initialValueRank != srcRank - 1)
6599 return emitOpError(
"initial value rank ")
6600 << initialValueRank <<
" has to be equal to " << srcRank - 1;
6606 for (
int i = 0; i < srcRank; i++) {
6607 if (i != reductionDim)
6608 expectedShape.push_back(srcShape[i]);
6610 if (!llvm::equal(initialValueShapes, expectedShape)) {
6611 return emitOpError(
"incompatible input/initial value shapes");
6615 Type eltType = getDestType().getElementType();
6617 return emitOpError(
"unsupported reduction type ")
6618 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
6627 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6628 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6629 StridedSliceConstantMaskFolder, TransposeFolder>(
6638 auto constOperand = adaptor.getInput();
6639 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6648 setResultRanges(getResult(), argRanges.front());
6653 arith::FastMathFlagsAttr fastmath,
6660 case CombiningKind::ADD:
6663 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6664 result = b.
createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6666 llvm_unreachable(
"invalid value types for ADD reduction");
6668 case CombiningKind::AND:
6672 case CombiningKind::MAXNUMF:
6673 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6674 "expected float values");
6675 result = b.
createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6677 case CombiningKind::MAXIMUMF:
6678 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6679 "expected float values");
6680 result = b.
createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6682 case CombiningKind::MINNUMF:
6683 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6684 "expected float values");
6685 result = b.
createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6687 case CombiningKind::MINIMUMF:
6688 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6689 "expected float values");
6690 result = b.
createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6692 case CombiningKind::MAXSI:
6696 case CombiningKind::MINSI:
6700 case CombiningKind::MAXUI:
6708 case CombiningKind::MUL:
6711 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6712 result = b.
createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6714 llvm_unreachable(
"invalid value types for MUL reduction");
6716 case CombiningKind::OR:
6720 case CombiningKind::XOR:
6726 assert(result &&
"unknown CombiningKind");
6738 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
6758 return builder.
create<MaskOp>(maskableOp->getLoc(),
6759 maskableOp->getResultTypes(), mask, maskableOp,
6776 mask, newValue, passthru);
6783 #define GET_ATTRDEF_CLASSES
6784 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6786 #define GET_OP_CLASSES
6787 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static SmallVector< Value > computeStrides(Location loc, RewriterBase &rewriter, ValueRange dynamicBasis, ArrayRef< int64_t > staticBasis)
Given a basis (in static and dynamic components), return the sequence of suffix products of the basis...
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
union mlir::linalg::@1183::ArityGroupAndKind::Kind kind
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Block represents an ordered list of Operations.
OpListType & getOperations()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.