40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Count up for set bits and down for cleared bits; a mix of set and
    // cleared bits means the mask format is unknown.
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      // ...
    }
  } else if (auto m = mask.getDefiningOp<vector::ConstantMaskOp>()) {
    // All bits are set iff every mask index covers its dimension; no bits are
    // set iff every index is <= 0.
    ArrayRef<int64_t> masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
      if (maskIdx > 0)
        allFalse = false;
    }
    // ...
  } else if (auto m = mask.getDefiningOp<vector::CreateMaskOp>()) {
    // Finds all-false create_masks. An all-true create_mask requires all dims
    // to be constants, which would already have folded to a constant_mask.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
    return MaskFormat::Unknown;
  }
  return MaskFormat::Unknown;
}
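// Illustrative examples (annotation added here, not upstream code): a
// vector.constant_mask [4] : vector<4xi1> covers the whole dimension and is
// classified AllTrue, while a vector.create_mask with a constant zero bound
// produces no set bits and is classified AllFalse.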
/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  builder.create<vector::YieldOp>(loc);
}
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}
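// Summary of the switch above: ADD and MUL accept integer, index, and float
// element types; the integer min/max and bitwise kinds accept only integer or
// index types; the floating-point min/max kinds accept only float types.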
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}
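// Illustrative example (annotation, not upstream code): for a transfer
// between memref<?x?x?xf32> and vector<2x4xf32>, the minor identity map is
// (d0, d1, d2) -> (d1, d2), i.e. the vector maps onto the innermost shaped
// dimensions.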
/// Check if `write` is of a constant splat and the masked `read` is padded
/// with the same splat value, i.e. the read could produce the same value as
/// the initial constant splat.
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for a constant splat as the source of the write.
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat())
    return false;
  // ...
}
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove the indices differ we know
      // the accesses are disjoint slices.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // Fast path: fully compose and simplify the affine expression.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For trailing dimensions we slice a part of the memref; make sure the
      // accessed intervals do not overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getSource() != transferB.getSource())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
/// Advance `position` to the next n-D slice position within `shape`, starting
/// from `offsets`. Returns failure when the end position is reached.
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();
    // Carry the overflow to the next loop iteration.
    posInDim = offsetInDim;
  }
  return failure();
}
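// Illustrative walk (annotation, not upstream code): with shape = [2, 3] and
// zero offsets, successive calls advance position [0,0] -> [0,1] -> [0,2] ->
// [1,0] -> [1,1] -> [1,2], and the call after [1,2] returns failure().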
/// Returns the integer numbers in `values`, which are expected to be constant
/// operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}
/// Returns the integer numbers in `foldResults`, which are expected to be
/// constant attributes.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  return ints;
}
/// Converts `foldResults` into Values; integer attributes become constant ops.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return builder
                          .create<arith::ConstantIndexOp>(
                              loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();
                    return cast<Value>(foldResult);
                  });
  return values;
}
std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
  if (value.getDefiningOp<vector::VectorScaleOp>())
    return 1;
  auto mul = value.getDefiningOp<arith::MulIOp>();
  if (!mul)
    return {};
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  return {};
}
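// Illustrative example (annotation, not upstream code): for
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vscale, %c4 : index
// getConstantVscaleMultiplier(%n) returns 4; a bare vector.vscale returns 1.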
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();
  addInterfaces<VectorInlinerInterface>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (isa<ub::PoisonAttrInterface>(value))
    return value.getDialect().materializeConstant(builder, value, type, loc);
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // A single parallel dim makes this a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}
std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();
  return success();
}
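// Illustrative example (annotation, assumed current syntax):
//   %1 = vector.multi_reduction <add>, %0, %acc [1]
//        : vector<4x8xf32> to vector<4xf32>
// Reducing dimension 1 leaves target shape [4], so the inferred return type
// is vector<4xf32> and anything else is rejected by the verifier above.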
/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
namespace {
// Only unit dimensions that are being reduced are folded. If a dimension is
// unit but not reduced, it is kept, preserving the output type.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      }
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // All dimensions are reduced and all reduced dimensions have size 1, so
      // a simple extraction does the job.
      SmallVector<int64_t> zeroIdx(shape.size(), 0);
      if (mask)
        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
                                                zeroIdx);
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace
void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind,
        vector, acc, fastMathFlags);
}
LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";
  return success();
}
/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
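// Typical use (annotation, not upstream code): lowerings that carry an
// arith::AtomicRMWKind (e.g. reductions coming from affine parallel loops)
// call getVectorReductionOp to obtain the matching vector.reduction, so
// AtomicRMWKind::addf yields a CombiningKind::ADD reduction of the operand.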
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    Value result;
    if (vectorType.getRank() == 0) {
      if (mask)
        mask = rewriter.create<ExtractElementOp>(loc, mask);
      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
    } else {
      if (mask)
        mask = rewriter.create<ExtractOp>(loc, mask, 0);
      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
    }

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace
void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  // ...
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(AffineMap::inferFromExprList(
          indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  // ...
}
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  // ...
  DictionaryAttr dictAttr;
  // ...
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert an array of strings into an array of IteratorType enums; tests
  // still use the string form of the 'iterator_types' attribute.
  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  SmallVector<Attribute> iteratorTypeAttrs;
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  // ...
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(result.getContext(),
                                             ContractionOp::getDefaultKind()));
  // ...
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  // ...
  std::array<VectorType, 2> maskTypes = {
      // ...
  };
  // ...
}
void ContractionOp::print(OpAsmPrinter &p) {
  // Collect the trait attributes to print in the leading dictionary.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums to the string representation; tests still
      // use the string form of the 'iterator_types' attribute.
      auto iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));
      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();
  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) !=
            rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}
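// Illustrative example (annotation, not upstream code): for lhs
// vector<2x3xf32> and rhs vector<3x4xf32>, the contracting map {(1, 0)} is
// valid because dim 1 of the LHS and dim 0 of the RHS both have size 3.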
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer the expected result vector type from the lhs/rhs maps and shapes.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    // Compose the result map with the (constant) extents map.
    // ...
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the expected vector type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each indexing map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}
/// Returns the mask type expected by this operation. The mask is shaped by
/// the iteration space of the contraction.
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the indexing maps, extract the size of each iteration-space
  // dimension from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhs shape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}
std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
/// Canonicalize contract(a, b, zero) + other into contract(a, b, other), for
/// both arith.addi and arith.addf via the AddOpType template parameter.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source) {
  result.addOperands({source});
  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
}
LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here.
  if (!adaptor.getPosition())
    return {};

  // Fold extractelement(splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extractelement(broadcast X) -> X when X is a scalar.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
      return broadcast.getSource();

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!pos || !src)
    return {};

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}

/// Returns true if `index` is in the range [0, maxIndex) or equal to the
/// poison marker `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}
void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  // ...
}

// ... (mixed static/dynamic position build overloads)
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}
bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}
LogicalResult vector::ExtractOp::verify() {
  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
/// Fold chains of ExtractOp in place by concatenating their positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}
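// Illustrative fold (annotation, not upstream code):
//   %1 = vector.extract %0[0] : vector<8xf32> from vector<4x8xf32>
//   %2 = vector.extract %1[1] : f32 from vector<8xf32>
// becomes %2 = vector.extract %0[0, 1] : f32 from vector<4x8xf32>.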
namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back the insert/transpose chain, composing transpose permutations,
/// until a foldable insert is found.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over the insert/transpose chain and fold where possible.
  Value fold();

private:
  /// Return true if the position `a` is fully contained within position `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if positions `a` and `b` agree on all entries where both are
  /// non-negative (sentinel entries never cause a mismatch).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in
  /// the result of the insert.
  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  /// Update `nextInsertOp` and `nextTransposeOp` for the next iteration.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  }

  LogicalResult handleTransposeOp();
  LogicalResult handleInsertOpWithMatchingPos(Value &res);
  LogicalResult handleInsertOpWithPrefixPos(Value &res);
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;
  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;
  /// Extract position augmented with trailing sentinels.
  SmallVector<int64_t> extractPosition;
  SmallVector<int64_t> sentinels;
};
} // namespace

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
  llvm::append_range(extractPosition, sentinels);
}
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
  return success();
}
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();
  // ...
  res = nextInsertOp.getSource();
  // If an internal transposition is present, canFold will be false.
  return success(canFold());
}
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();
  // ... (check that the inserted position is a prefix of extractPosition and
  // rewrite the position accordingly)
  res = nextInsertOp.getSource();
  return success();
}
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  // If there is nothing to fold, or an internal transposition blocks us, bail.
  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
    return Value();

  // Otherwise, fold by updating the op in place and returning its result.
  extractOp.setStaticPosition(
      ArrayRef(extractPosition).take_front(extractedRank));
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();
}
Value ExtractFromInsertTransposeChainState::fold() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, compose the map and iterate.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: the inserted position is a prefix of extractPosition; extract a
    // portion of the insert's source.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: the positions intersect on non-sentinel values; can't fold, and
    // a separate chain starts here.
    // ...
    // Case 5: no intersection; look past this insert.
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
/// Returns true if the operation has a 0-D vector type operand or result.
static bool hasZeroDimVectors(Operation *op) {
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;
  };

  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
/// Fold an ExtractOp whose source comes from a BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();

  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                       : 0;
  };

  // If splat or broadcast from a scalar, just return the source scalar.
  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
    return source;

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank >= broadcastSrcRank)
    return Value();
  // Check that the dimensions of the result have not been broadcasted.
  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();

  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();

  // Detect the positions that come from "dim-1" broadcasting; set the
  // matching extract position to 0 when extracting from the source operand.
  llvm::SetVector<int64_t> broadcastedUnitDims =
      broadcastOp.computeBroadcastedUnitDims();
  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))
      extractPos[i] = 0;
  // The leading `rankDiff` dimensions are new broadcasted dims; drop the
  // matching extract positions when extracting from the source operand.
  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  extractOp.setOperand(0, source);
  extractOp.setStaticPosition(extractPos);
  return extractOp.getResult();
}
/// Fold an ExtractOp whose source comes from a ShuffleOp.
static Value foldExtractFromShuffle(ExtractOp extractOp) {
  // Dynamic positions are not folded.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // TODO: 0-D or multi-dimensional vectors not supported yet.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Select the underlying vector and adjust the extract position.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
}
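// Illustrative fold (annotation, not upstream code): extracting position 3 of
//   vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
// reads mask entry 15, i.e. element 7 of %b, so the extract is rewritten to
// operate directly on %b.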
/// Fold an ExtractOp whose source comes from a ShapeCastOp.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the source.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source, then
  // linearize the extract position.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides of the shapeCast source and delinearize the
  // position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
}
/// Fold an ExtractOp from an ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the strided slice.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
}
/// Fold an ExtractOp fed from a chain of InsertStridedSliceOps.
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
  if (!insertOp)
    return Value();

  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the inserted vector.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap and cannot fold.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getVectorMutable().assign(insertOp.getSource());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // If the extracted chunk is disjoint from the inserted chunk, keep
    // looking through the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}
/// Try to fold the extraction of a scalar from a vector defined by
/// vector.from_elements, e.g. extracting element [0] of
/// vector.from_elements %a, %b folds to %a.
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  // Look for extract(from_elements).
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute the flattened/linearized index and fold to the operand.
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
}
/// If the dynamic indices of `op` are in fact constants, fold them into the
/// static position in place and return the op's result; otherwise collect the
/// still-dynamic operands into `operands`.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();

  // If there are no dynamic operands, there is nothing to fold.
  if (!dynamicPosition.size())
    return {};

  // `index` iterates over `dynamicPosition`; `opChange` records whether the
  // op must be updated in place.
  unsigned index = 0;
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (!ShapedType::isDynamic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      staticPosition[i] = attr.getInt();
      opChange = true;
      continue;
    }
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    return op.getResult();
  }
  return {};
}
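// Illustrative fold (annotation, not upstream code): given
//   %c1 = arith.constant 1 : index
//   %e  = vector.extract %v[%c1, 2] : f32 from vector<4x8xf32>
// the dynamic index is folded in place to %e = vector.extract %v[1, 2].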
/// Fold an insert or extract operation into a poison value when a poison
/// index is found at any dimension of the static position.
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
                                                ArrayRef<int64_t> staticPos,
                                                int64_t poisonVal) {
  if (!llvm::is_contained(staticPos, poisonVal))
    return {};
  return ub::PoisonAttr::get(context);
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // An extract from a poison source is itself poison.
  if (llvm::isa_and_nonnull<ub::PoisonAttr>(adaptor.getVector()))
    return adaptor.getVector();
  // Fold an extract with no indices and matching types to its source.
  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
    return getVector();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  // ...
  return OpFoldResult();
}
namespace {

// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    Operation *defOp = extractOp.getVector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    if (extractOp.getType() == source.getType())
      return failure();
    auto getRank = [](Type type) {
      return llvm::isa<VectorType>(type)
                 ? llvm::cast<VectorType>(type).getRank()
                 : 0;
    };
    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    // The folder handles the case where the source rank is lower than the
    // result rank; only rewrite the other cases here.
    if (extractResultRank < broadcastSrcRank)
      return failure();

    // Special case: the broadcast source is a 0-D vector.
    if (extractResultRank == 0) {
      assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
      return success();
    }
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        extractOp, extractOp.getType(), source);
    return success();
  }
};
// Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a splat vector ConstantOp.
    Value sourceVector = extractOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();
    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
    if (!splat)
      return failure();
    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
    return success();
  }
};
// Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
class ExtractOpNonSplatConstantFolder final
    : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (extractOp.hasDynamicPosition())
      return failure();

    // Return if the operand is not defined by a compatible vector ConstantOp.
    Value sourceVector = extractOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
    if (vecTy.isScalable())
      return failure();

    // The splat case is handled by 'ExtractOpSplatConstantFolder'.
    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
    if (!dense || dense.isSplat())
      return failure();

    // Calculate the linearized position of the chunk of elements to extract.
    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
    copy(extractOp.getStaticPosition(), completePositions.begin());
    int64_t elemBeginPosition =
        linearize(completePositions, computeStrides(vecTy.getShape()));
    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;

    TypedAttr newAttr;
    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
      SmallVector<Attribute> elementValues(
          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
    } else {
      newAttr = *denseValuesBegin;
    }

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
    return success();
  }
};
// Pattern to rewrite an ExtractOp(CreateMaskOp) -> CreateMaskOp.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = false;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         ++dimIdx) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim are unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the bound, the extracted mask is
        // all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dim is not all-true, and with a dynamic index we don't know
        // whether the extraction lands in the true or false region.
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();

  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();

  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
  return success();
}
/// Try to canonicalize the extraction of a subvector from a vector defined by
/// vector.from_elements, rewriting it as a smaller vector.from_elements.
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of the first extracted element and linearize it.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);
  int flatIndex = 0;
  int stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}
std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
/// Rewrite a vector.from_elements with all-identical operands as a splat.
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
                                                PatternRewriter &rewriter) {
  if (!llvm::all_equal(fromElementsOp.getElements()))
    return failure();
  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp,
                                       fromElementsOp.getType(),
                                       fromElementsOp.getElements().front());
  return success();
}
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
llvm::SetVector<int64_t>
BroadcastOp::computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
                                        ArrayRef<int64_t> dstShape) {
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  llvm::SetVector<int64_t> res;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");
      res.insert(dstDim);
    }
    ++dstDim;
  }
  return res;
}

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  // A scalar broadcast has no unit-dim broadcasts.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (!srcVectorType)
    return {};
  return computeBroadcastedUnitDims(srcVectorType.getShape(),
                                    getResultVectorType().getShape());
}
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness: the source shape must equal dstShape with the
  // broadcasted dimensions removed.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 2. If the value is a scalar, just broadcast to dstShape.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 3. Broadcasting with a leading broadcastShape then transposing back
  // into place: compute broadcastShape (broadcast dims first, then the source
  // dims) and the permutation that restores dstShape's order.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // 3.a. Broadcasted dims move to the head of broadcastShape.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 3.b. Non-broadcasted dims come from the source shape and must be
      // permuted back into place.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no "dim-1" broadcasts left.
  assert(BroadcastOp::computeBroadcastedUnitDims(srcVectorType.getShape(),
                                                 broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);

  // Step 4. If any dimension needs to be permuted, emit a vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  return res;
}
BroadcastableToResult mlir::vector::isBroadcastableTo(
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcast a scalar to a vector of the same element type.
  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
    return BroadcastableToResult::Success;
  // From here on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // The source must match or be a singleton for all trailing dimensions;
  // leading dimensions are simply duplicated.
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have mismatching dims (in the sense of vector.broadcast semantics) been
    // encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine; everything else is rejected when mixing
        // fixed-width and scalable dims.
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;
        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}
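// Illustrative cases (annotation, not upstream code): vector<4xf32> ->
// vector<2x4xf32> succeeds (leading dims are duplicated), vector<1xf32> ->
// vector<8xf32> succeeds (a "dim-1" stretch), and vector<3xf32> ->
// vector<2x4xf32> fails with DimensionMismatch.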
LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  return {};
}
// Fold broadcast(broadcast(x)) into broadcast(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        broadcastOp, broadcastOp.getResultVectorType(),
        srcBroadcast.getSource());
    return success();
  }
};

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder>(context);
}
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but the leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // Construct the resulting type: the leading dimension matches the mask
  // length, all trailing dimensions match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // For consistency: a 0-D shuffle is a noop.
  if (v1Type.getRank() == 0)
    return getV1();

  // Fold shuffle V1, V2 with an identity mask over V1 to V1, and with a pure
  // V2 step mask to V2.
  auto mask = getMask();
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return getV1();
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return getV2();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // Poison inputs need special handling as they are not DenseElementsAttr;
  // if an index is poison, select the first element of the first non-poison
  // input.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    v2Elements =
        to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    v1Elements =
        to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    // TODO: Return a partial poison vector when supported by the UB dialect.
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }
    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
// Pattern to rewrite a 0-D shuffle with a single-entry mask to the selected
// operand.
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
                                PatternRewriter &rewriter) const override {
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
      return failure();
    if (mask.size() != 1)
      return failure();
    // ...
  }
};
/// Pattern to rewrite shuffle(splat(x), splat(x)) as splat(x).
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();

    if (!v1Splat || !v2Splat)
      return failure();

    if (v1Splat.getInput() != v2Splat.getInput())
      return failure();

    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
    return success();
  }
};
/// Pattern to rewrite a fixed-size interleave expressed via vector.shuffle to
/// vector.interleave.
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent a scalable interleave");

    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp types don't match an interleave");
    }

    ArrayRef<int64_t> shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
        return rewriter.notifyMatchFailure(op,
                                           "ShuffleOp mask not interleaving");
    }

    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
    return success();
  }
};
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);
}

void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}
void vector::InsertElementOp::build(OpBuilder &builder, OperationState &result,
                                    Value source, Value dest) {
  build(builder, result, source, dest, {});
}

LogicalResult vector::InsertElementOp::verify() {
  auto dstVectorType = getDestVectorType();
  if (dstVectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (dstVectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}
OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here.
  if (!adaptor.getPosition())
    return {};

  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!src || !dst || !pos)
    return {};
  // ...
  auto dstElements = dst.getValues<Attribute>();

  SmallVector<Attribute> results(dstElements);

  uint64_t posIdx = pos.getInt();
  if (posIdx >= results.size())
    return {};
  results[posIdx] = src;

  return DenseElementsAttr::get(getDestVectorType(), results);
}
void InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  // ...
}

// ... (static-position build overload)
  SmallVector<OpFoldResult> posVals;
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);
}

// ... (mixed static/dynamic position build overload)
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult InsertOp::verify() {
  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding "
                  "dest vector dimension";
      }
    }
  }
  return success();
}
// If the insert only inserts unit dimensions, it can be rewritten as a
// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getSource());
    return success();
  }
};
/// Pattern to rewrite the insertion of a splat into a splat of the same value
/// as a splat of the destination type.
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();

    if (!srcSplat || !dstSplat)
      return failure();

    if (srcSplat.getInput() != dstSplat.getInput())
      return failure();

    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(),
                                         srcSplat.getInput());
    return success();
  }
};
// Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the destination vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (op.hasDynamicPosition())
      return failure();

    // Return if the destination is not defined by a compatible constant.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();
    auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
    if (!denseDest)
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    Value sourceValue = op.getSource();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // Calculate the linearized position of the chunk of elements to insert.
    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
    copy(op.getStaticPosition(), completePositions.begin());
    int64_t insertBeginPosition =
        linearize(completePositions, computeStrides(destTy.getShape()));

    SmallVector<Attribute> insertedValues;
    Type destEltType = destTy.getElementType();
    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
      for (auto value : denseSource.getValues<Attribute>())
        insertedValues.push_back(convertIntegerAttr(value, destEltType));
    } else {
      insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
    }

    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
    copy(insertedValues, allValues.begin() + insertBeginPosition);
    auto newAttr = DenseElementsAttr::get(destTy, allValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
private:
  /// Convert an integer attribute whose type differs from `expectedType`
  /// (e.g. a constant held with a different attribute type than the result
  /// type) into an attribute of `expectedType`.
  static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
    return attr;
  }
};
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertOpConstantFolder>(context);
}

OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
  // Fold an insert with no indices and matching types to its source.
  if (getNumIndices() == 0 && getSourceType() == getType())
    return getSource();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  // ...
  return {};
}
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}
// Returns success if all integers in `arrayAttr` are in the admissible
// interval: [min, max) when `halfOpen` is true, [min, max] otherwise.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
// Returns success if all integers in `arrayAttr` are in the admissible
// per-dimension interval: [min, shape[i]) when `halfOpen` is true,
// [min, shape[i]] otherwise.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
// Returns success if, for each index, val1[index] + val2[index] is in the
// admissible per-dimension interval: [min, shape[i]) when `halfOpen` is true,
// [min, shape[i]] otherwise.
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto srcSplatOp =
        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
    auto destSplatOp =
        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();

    if (!srcSplatOp || !destSplatOp)
      return failure();

    if (srcSplatOp.getInput() != destSplatOp.getInput())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
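// Illustrative IR for the pattern above (editor's sketch using standard
// vector-dialect syntax; not taken verbatim from the upstream tests):
//
//   %s = vector.splat %x : vector<2x4xf32>
//   %d = vector.splat %x : vector<16x4x8xf32>
//   %r = vector.insert_strided_slice %s, %d
//          {offsets = [0, 0, 2], strides = [1, 1]}
//        : vector<2x4xf32> into vector<16x4x8xf32>
//
// Both operands splat the same scalar %x, so %r is replaced by %d.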
/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getSource()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check that the slices have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
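// Illustrative IR for the insert-of-extract round trip (editor's sketch):
//
//   %e = vector.extract_strided_slice %d
//          {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
//        : vector<4x8xf32> to vector<2x4xf32>
//   %r = vector.insert_strided_slice %e, %d
//          {offsets = [0, 2], strides = [1, 1]}
//        : vector<2x4xf32> into vector<4x8xf32>
//
// The slice is written back exactly where it was read from, so %r == %d.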
// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the destination constant is dead apart from this use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if the operands are not defined by compatible vector ConstantOps.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    TypedValue<VectorType> sourceValue = op.getSource();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Support poison.
    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Calculate the destination element indices by enumerating all slice
    // positions within the destination and linearizing them.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getSource();
  return {};
}
//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand (AXPY).
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what is currently supported for scalable
      // vectors; it could be relaxed given a use case.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
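// Illustrative IR (editor's sketch based on the vector dialect docs):
//
//   %o = vector.outerproduct %a, %b : vector<4xf32>, vector<8xf32>
//     // -> vector<4x8xf32>, using the default ADD combining kind
//   %m = vector.outerproduct %a, %s : vector<4xf32>, f32
//     // AXPY variant: scalar %s scales %a, yielding vector<4xf32>
//
// For the first form, the expected mask type mirrors the result shape:
// vector<4x8xi1>.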
//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//

// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for the remaining dims.
// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  if (failed(
          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
      failed(
          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }

  return success();
}
// Replace an extract_strided_slice with the source of a matching
// insert_strided_slice when the extracted chunk is entirely contained in the
// inserted chunk.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the rank of extract is greater than the rank of insert, we are likely
    // extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        // The intervals overlap; if the extract is not fully contained we
        // have a partial overlap.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk is a subset of the inserted chunk.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getSource());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunks are disjoint, keep looking through the insert chain.
    if (disjoint) {
      insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
      continue;
    }
    // The chunks partially overlap; we cannot fold.
    return failure();
  }
  return failure();
}
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getVector();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();
  return {};
}
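// Illustrative IR (editor's sketch; the offsets/sizes/strides triple must
// satisfy the checks in verify() above):
//
//   %1 = vector.extract_strided_slice %0
//          {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
//        : vector<4x8x16xf32> to vector<2x4x16xf32>
//
// fold() returns the source unchanged when the slice covers the whole
// vector, i.e. when the result type equals the source type.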
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if the extract has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes and slice offsets/sizes.
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the slice of the vector mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any slice mask dim size is zero, set them all to zero (the masked
    // region is a conjunction of the per-dimension intervals).
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    // Replace with a ConstantMaskOp carrying the sliced mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
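// Worked example for StridedSliceConstantMaskFolder (editor's sketch):
//
//   %m = vector.constant_mask [2, 2] : vector<4x4xi1>
//   %s = vector.extract_strided_slice %m
//          {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
//        : vector<4x4xi1> to vector<2x4xi1>
//
// Per dimension: min(1 + 2, 2) - 1 = 1 and min(0 + 4, 2) - 0 = 2, so the
// pattern rewrites %s to:
//
//   %s = vector.constant_mask [1, 2] : vector<2x4xi1>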
// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
class StridedSliceSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a splat ConstantOp.
    Value sourceVector = extractStridedSliceOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
    if (!splat)
      return failure();

    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
                                          splat.getSplatValue<Attribute>());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                   newAttr);
    return success();
  }
};
// Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
// ConstantOp.
class StridedSliceNonSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a non-splat ConstantOp.
    Value sourceVector = extractStridedSliceOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    // The splat case is handled by `StridedSliceSplatConstantFolder`.
    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
    if (!dense || dense.isSplat())
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();

    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

    VectorType sliceVecTy = extractStridedSliceOp.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t sliceRank = sliceVecTy.getRank();

    // Expand offsets to the source rank; trailing dims are taken whole.
    SmallVector<int64_t, 4> offsets(sliceRank, 0);
    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());

    // Calculate the slice elements by enumerating all slice positions and
    // linearizing them. Lexicographic enumeration yields monotonically
    // increasing linearized positions.
    auto denseValuesBegin = dense.value_begin<Attribute>();
    SmallVector<Attribute> sliceValues;
    sliceValues.reserve(sliceVecTy.getNumElements());
    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
    do {
      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
      assert(linearizedPosition < sourceVecTy.getNumElements() &&
             "Invalid index");
      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
    } while (
        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));

    assert(static_cast<int64_t>(sliceValues.size()) ==
               sliceVecTy.getNumElements() &&
           "Invalid number of slice elements");
    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                   newAttr);
    return success();
  }
};
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the broadcast source match the
    // destination of the extract. If so, the original dimensions are untouched
    // and a plain broadcast suffices.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.getSource();
    // If the inner dimensions don't match, extract from the broadcast source
    // first, then broadcast the extracted value. Scalar-like sources need no
    // extract.
    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
    if (!lowerDimMatch && !isScalarSrc) {
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};
/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto splat = op.getVector().getDefiningOp<SplatOp>();
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
    return success();
  }
};
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build: walk the
    // dimensions from innermost outwards and stop at the first slice
    // dimension that is not full-size.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the extract would have no offsets, this extract_strided_slice is the
    // identity and should have been handled by other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the innermost dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
                                                       extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};
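// Illustrative rewrite (editor's sketch, mirroring the upstream pattern
// documentation):
//
//   %1 = vector.extract_strided_slice %arg0
//          {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8],
//           strides = [1, 1, 1, 1, 1]}
//        : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
// becomes:
//   %0 = vector.extract %arg0[0, 0, 0, 0]
//        : vector<8xi8> from vector<8x1x1x2x8xi8>
//   %1 = vector.shape_cast %0 : vector<8xi8> to vector<1x1x1x1x8xi8>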
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
      context);
}
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

/// 1. Builder that sets padding to zero and an empty mask (attribute variant).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 2. Builder that sets padding to zero and an empty mask (map variant).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        inBoundsAttr);
}

/// 3. Builder that sets the permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, Value padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder that sets padding to zero and the permutation map to
/// 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           std::optional<ArrayRef<bool>> inBounds) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, padding, inBounds);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
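// Examples of maps accepted/rejected by verifyPermutationMap (editor's
// sketch):
//
//   (d0, d1, d2) -> (d2, d1)   // projected permutation: accepted
//   (d0, d1) -> (d1, 0)        // the zero constant is allowed: accepted
//   (d0, d1) -> (d0 + d1)      // rejected: not a bare dim or zero
//   (d0, d1) -> (d1, d1)       // rejected: d1 appears twice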
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds attribute if all dims are out-of-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}
VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

  // The mask shape can't be empty (rank-0); use shape {1} in that case.
  if (maskShape.empty())
    maskShape.push_back(1);

  SmallVector<bool> scalableDims =
      applyPermutationMap(invPermMap, vecType.getScalableDims());

  return VectorType::get(maskShape, i1Type, scalableDims);
}
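// Example of mask-type inference (editor's sketch): for a transposing read
//
//   %r = vector.transfer_read %A[%i, %j], %pad
//          {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//        : memref<?x?xf32>, vector<4x8xf32>
//
// the inverse permutation maps the vector shape [4, 8] back to the memref
// iteration order, so the inferred mask type is vector<8x4xi1>.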
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo, paddingInfo, maskInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // The mask type is computed from the vector type and the permutation map
    // rather than spelled out in the op's type signature.
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type: padding type must match it.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation; used for verification.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}

template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Indices of non-broadcast dims, used when analysing broadcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds; check whether we can statically determine
    // that it is in fact in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }

    changed |= inBounds;
    newInBounds.push_back(inBounds);
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are in-bounds, all
  // broadcast dims are in-bounds as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build a BoolArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}
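// Example (editor's sketch): for
//
//   %r = vector.transfer_read %m[%c0, %c0], %pad
//        : memref<8x8xf32>, vector<4x4xf32>
//
// both indices are the constant 0 and 0 + 4 <= 8 holds in each dimension,
// so the folder above tightens in_bounds to [true, true].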
template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();

  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();

  op.getMaskMutable().clear();
  return success();
}
/// Fold a read-after-write (RAW) on tensors:
///
///   %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///     : vector<1x4xf32>, tensor<4x4xf32>
///   %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
///     : tensor<4x4xf32>, vector<1x4xf32>
///
/// folds into %v0.
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldTransferFullMask(*this)))
    return getResult();
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}
std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
/// Rewrites a transfer_read that reads back exactly what a preceding
/// transfer_write stored (on tensors) into a broadcast + transpose of the
/// written vector.
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.hasOutOfBoundsDim() ||
        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
      return failure();
    auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // TODO: Support cases where a dim is explicitly written but implicitly
    // read (i.e., a unit dim that is rank reduced).
    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
      return failure();
    if (readOp.getIndices() != defWrite.getIndices() ||
        readOp.getMask() != defWrite.getMask())
      return failure();
    Value vec = defWrite.getVector();
    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
    AffineMap map = readMap.compose(writeMap);
    if (map.getNumResults() == 0)
      return failure();
    // Calculate the permutation to apply to go from the stored vector to the
    // read vector.
    SmallVector<unsigned> permutation;
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
      return failure();

    Location loc = readOp.getLoc();
    // Calculate the broadcast shape by applying the reverse permutation to the
    // final shape we want.
    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
    SmallVector<int64_t> broadcastShape(destShape.size());
    SmallVector<bool> broadcastScalableFlags(destShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType(),
        broadcastScalableFlags);
    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
                                                     transposePerm);
    return success();
  }
};

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//

/// 1. Builder with type inference (attribute variant with mask).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// 2. Builder with type inference that sets an empty mask.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 3. Builder with type inference that sets an empty mask (map variant).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder with type inference that sets an empty mask and the permutation
/// map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo, maskInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation; used for verification.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
/// Fold a transfer_write of a full-tensor transfer_read back into the original
/// tensor:
///
///   %v = vector.transfer_read %t0[%c0...] {in_bounds = [true...]}
///   %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]}
///
/// folds into %t0 when the shapes match exactly.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getSource().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // All indices must be the constant zero.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getSource());
  return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getSource() == write.getSource() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}
/// Fold a write-after-read (WAR) on tensors: writing back exactly what was
/// just read, at the same indices, is a no-op.
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getSource());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  if (succeeded(foldTransferFullMask(*this)))
    return success();
  return memref::foldMemRefCast(*this);
}
std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
/// Remove dead transfer writes from the SSA chain (on tensors) so that they
/// can be eliminated trivially.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite =
        writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getSourceMutable().assign(defWrite.getSource());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write has no other use, we can safely look at the
      // store before it.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
/// Swap a tensor.extract_slice of a vector.transfer_write with the write
/// itself when the write covers the full tensor and the slice is reinserted
/// with unit stride at zero offset.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if the use-def chain is rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has a non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if vector::TransferWriteOp has a non-zero offset.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TranferWriteOp has non-zero offset");
    }

    // Fail if the ranks of the slice ops differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the transfer_write may not write the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the tensor::ExtractSliceOp in front of the TransferWriteOp.
    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
        extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
        insertOp.getMixedStrides());
    auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
        transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
        transferOp.getIndices(), transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(SmallVector<bool>(vectorShape.size(), true)));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
//===----------------------------------------------------------------------===//
// LoadOp / StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If rank == 0 or the vector holds a single element, this is equivalent to
  // a scalar load/store, so no stride restrictions apply.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}
LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
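// Illustrative folds (editor's sketch): with %t = vector.constant_mask [4]
// : vector<4xi1> (all true) and %f = vector.constant_mask [0] : vector<4xi1>
// (all false):
//
//   vector.maskedload %base[%c0], %t, %p  -->  vector.load %base[%c0]
//   vector.maskedload %base[%c0], %f, %p  -->  %p (the pass-through value)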
//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation; used for verification.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
/// Check if `indexVec` is a constant 1-D vector of consecutive values
/// [0, 1, 2, ...].
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  if (indexVec.getDefiningOp<StepOp>())
    return success();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

/// Fold gathers with consecutive offsets [0, 1, 2, ...] into a contiguous
/// maskedload. Only 1-D fixed vectors are supported for now.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getIndices(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

/// Fold scatters with consecutive offsets [0, 1, 2, ...] into a contiguous
/// maskedstore. Only 1-D fixed vectors are supported for now.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
    return success();
  }
};

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}
//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//

LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//

LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//

void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
/// Returns true if each element of 'a' is equal to the product of a
/// contiguous sequence of the elements of 'b'. Returns false otherwise.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  auto isOne = [](int64_t v) { return v == 1; };

  // Special case for n-D to 0-D shape cast: 'b' must be all ones to be shape
  // cast to a 0-D vector.
  if (rankA == 0 && llvm::all_of(b, isOne))
    return true;

  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1: include them
    // into the contiguous sequence.
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  return i == rankA && j == rankB;
}
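// Examples (editor's sketch) of how isValidShapeCast groups dimensions:
//
//   vector<4x8xf32>   <-> vector<32xf32>  : valid, [4, 8] -> 32
//   vector<2x3x4xf32> <-> vector<6x4xf32> : valid, [2, 3] -> 6, [4] -> 4
//   vector<2x3x4xf32> <-> vector<4x6xf32> : invalid, 4 is not a prefix
//                                           product of [2, 3, 4]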
static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {
  // Check that element type is the same.
  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
    return op->emitOpError("source/result vectors must have same element type");
  auto sourceShape = sourceVectorType.getShape();
  auto resultShape = resultVectorType.getShape();

  // Check that the product of source dim sizes matches the product of result
  // dim sizes.
  int64_t sourceDimProduct = std::accumulate(
      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
  int64_t resultDimProduct = std::accumulate(
      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
  if (sourceDimProduct != resultDimProduct)
    return op->emitOpError("source/result number of elements must match");

  // Check the rank expansion/contraction cases.
  unsigned sourceRank = sourceVectorType.getRank();
  unsigned resultRank = resultVectorType.getRank();
  if (sourceRank < resultRank) {
    if (!isValidShapeCast(sourceShape, resultShape))
      return op->emitOpError("invalid shape cast");
  } else if (sourceRank > resultRank) {
    if (!isValidShapeCast(resultShape, sourceShape))
      return op->emitOpError("invalid shape cast");
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return op->emitOpError("different number of scalable dims at source (")
           << sourceNScalableDims << ") and result (" << resultNScalableDims
           << ")";
  return success();
}
LogicalResult ShapeCastOp::verify() {
  auto sourceVectorType =
      llvm::dyn_cast_or_null<VectorType>(getSource().getType());
  auto resultVectorType =
      llvm::dyn_cast_or_null<VectorType>(getResult().getType());

  // Check if source/result are of vector type.
  if (sourceVectorType && resultVectorType)
    return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);

  return success();
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  // No-op shape cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling shape casts.
  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Only allow valid transitive folding.
    VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
    VectorType resultType = llvm::cast<VectorType>(getResult().getType());
    if (srcType.getRank() < resultType.getRank()) {
      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
        return {};
    } else if (srcType.getRank() > resultType.getRank()) {
      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
        return {};
    } else {
      return {};
    }

    setOperand(otherOp.getSource());
    return getResult();
  }

  // Canceling broadcast and shape cast ops.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == getType())
      return bcastOp.getSource();
  }

  return {};
}
/// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto constantOp =
        shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
    if (!constantOp)
      return failure();
    // Only handle splat for now.
    auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    auto newAttr =
        DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
                               dense.getSplatValue<Attribute>());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
    return success();
  }
};
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
///   vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  // TODO: Add support for 0-D vectors.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
/// Folds qualifying shape_cast(create_mask) into a new create_mask.
///
/// Looks at `vector.shape_cast` ops that simply "drop" the trailing unit
/// dimensions. If the input vector comes from `vector.create_mask` and the
/// corresponding mask input value is the constant 1, it is safe to fold the
/// shape_cast into the create_mask.
///
/// BEFORE:
///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
/// AFTER:
///    %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    // No unit dims to drop.
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check that the trailing mask operands are the constant 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check that the trailing mask dim sizes are 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
/// This only applies when the shape of the broadcast source is a suffix of the
/// shape of the result (i.e. when a broadcast without reshape is expressive
/// enough to capture the result in a single op), or when the element counts
/// agree and a shape_cast of the broadcast source suffices.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    ArrayRef<int64_t> broadcastSourceShape;
    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
      broadcastSourceShape = srcType.getShape();
    ArrayRef<int64_t> shapeCastTargetShape =
        shapeCastOp.getResultVectorType().getShape();

    // If `broadcastSourceShape` is a suffix of the result, we can just replace
    // with a broadcast to the final shape.
    if (broadcastSourceShape ==
        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, shapeCastOp.getResultVectorType(),
          broadcastOp.getSource());
      return success();
    }

    // Otherwise, if the final result has the same element count, we can
    // replace with a shape_cast of the broadcast source.
    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
      if (srcType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    return failure();
  }
};
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
              ShapeCastBroadcastFolder>(context);
}
//===----------------------------------------------------------------------===//
// VectorBitCastOp
//===----------------------------------------------------------------------===//

LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern into both halves.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower-width element across the wider element.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(),
                                        IntegerAttr::get(dstElemType, intBits));
        }
      }
    }
  }

  return {};
}
//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}
LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto attr =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
    if (attr.isSplat())
      return attr.reshape(getResultVectorType());

  // Eliminate identity transpose ops: the input dimensions remain in their
  // original order after the transpose.
  SmallVector<int64_t, 4> perm = getPermutation();
  for (int64_t i = 0, e = perm.size(); i < e; i++) {
    if (perm[i] != i)
      return {};
  }

  // The permutation is the identity; return the input vector.
  return getVector();
}
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify the transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}
std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input is not defined by another transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};
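// Worked example (editor's sketch): two back-to-back transposes
//
//   %t0 = vector.transpose %v,  [1, 0] : vector<2x3xf32> to vector<3x2xf32>
//   %t1 = vector.transpose %t0, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
//
// compose to result[i] = perm1[perm2[i]] = [0, 1], i.e. the identity, and
// the identity transpose is then folded away by TransposeOp::fold.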
// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
      return success();
    }

    return failure();
  }
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
    if (!splatOp)
      return failure();

    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
    return success();
  }
};
/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};
void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
              TransposeFolder, FoldTransposeSplat>(context);
}
//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that the array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all are (the mask region
  // is a conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//

void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp when all mask
/// operands are statically known.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank-zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the `constantDims`.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value: for non-scalable dims any value folds; for scalable
        // dims only zero (all-false) is supported.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple (e.g. 4 x vscale): must be all-true to fold.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to the constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dim sizes is zero, set all dims to zero.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace 'createMaskOp' with a ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
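// Illustrative fold (editor's sketch): with all operands constant,
//
//   %c2 = arith.constant 2 : index
//   %c3 = arith.constant 3 : index
//   %m  = vector.create_mask %c2, %c3 : vector<4x4xi1>
//
// becomes:
//
//   %m = vector.constant_mask [2, 3] : vector<4x4xi1>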
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
//===----------------------------------------------------------------------===//
// MaskOp
//===----------------------------------------------------------------------===//

void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse all the operands.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();

  // Optional passthru operand.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse the op region.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  if (parsePassthru.succeeded())
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();

  return success();
}

void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print the single masked operation and skip the terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // Keep the default yield terminator if the number of masked operations is
  // not what we expect; verification will then fail.
  Block &block = region.front();
  if (block.getOperations().size() != 2)
    return;

  // Replace the default yield terminator with one that returns the results of
  // the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  Operation *oldYieldOp = &block.back();
  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");

  // An empty vector.mask op can't have any results.
  if (maskedOp == oldYieldOp)
    return;

  opBuilder.setInsertionPoint(oldYieldOp);
  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
  oldYieldOp->dropAllReferences();
  oldYieldOp->erase();
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask: nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Result checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Folds a vector.mask with an all-true mask by hoisting the maskable
/// operation out of the region.
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (isEmpty())
    return failure();

  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move the maskable operation outside of the `vector.mask` region.
  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
/// Elides empty vector.mask operations with or without return values.
/// Propagates the values yielded by the terminator, if any, or erases the op.
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
    if (maskingOp.getMaskableOp())
      return failure();

    if (!maskOp.isEmpty())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    if (terminator.getNumOperands() == 0) {
      rewriter.eraseOp(maskOp);
      return success();
    }

    rewriter.replaceOp(maskOp, terminator.getOperands());
    return success();
  }
};

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ElideEmptyMaskOp>(context);
}
// MaskingOpInterface definitions.

/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check that the reduction dimension is < rank(src).
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) == rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check the shapes of the initial value and the source.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify the supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // SplatElementsAttr::get treats a single value as a splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  return builder.create<MaskOp>(maskableOp->getLoc(),
                                maskableOp->getResultTypes(), mask, maskableOp,
                                createMaskOpRegion);
  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
                                         mask, newValue, passthru);
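
// Minimal usage sketch (builder and values assumed): wrap an existing
// maskable operation, e.g. a vector.transfer_read named `transferReadOp`,
// in a vector.mask region. Passing a null Value() as the mask returns the
// operation unchanged.
//
//   Operation *masked = vector::maskOperation(builder, transferReadOp, mask,
//                                             /*passthru=*/Value());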
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"