39 #include "llvm/ADT/ArrayRef.h"
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/StringSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/bit.h"
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
73 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75 for (
bool b : denseElts.getValues<
bool>())
78 else if (!b && val <= 0)
92 auto shape = m.getType().getShape();
95 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96 if (maskIdx < dimSize)
109 auto maskOperands = m.getOperands();
110 for (
Value operand : maskOperands) {
111 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
126 builder.
create<vector::YieldOp>(loc);
132 switch (combiningKind) {
133 case CombiningKind::ADD:
134 case CombiningKind::MUL:
137 case CombiningKind::MINSI:
138 case CombiningKind::MAXUI:
139 case CombiningKind::MAXSI:
140 case CombiningKind::AND:
141 case CombiningKind::OR:
142 case CombiningKind::XOR:
144 case CombiningKind::MINNUMF:
145 case CombiningKind::MAXNUMF:
146 case CombiningKind::MINIMUMF:
147 case CombiningKind::MAXIMUMF:
148 return llvm::isa<FloatType>(elementType);
154 VectorType vectorType) {
155 int64_t elementVectorRank = 0;
156 VectorType elementVectorType =
157 llvm::dyn_cast<VectorType>(shapedType.getElementType());
158 if (elementVectorType)
159 elementVectorRank += elementVectorType.getRank();
162 if (shapedType.getRank() == 0 &&
168 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169 shapedType.getContext());
176 vector::TransferReadOp read) {
177 auto readMask = read.getMask();
178 auto writeMask = write.getMask();
184 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185 if (!couldBeSameSplat)
190 m_Constant<DenseElementsAttr>(&splatAttr)) ||
202 vector::TransferReadOp read) {
203 return !defWrite.hasOutOfBoundsDim() &&
204 defWrite.getIndices() == read.getIndices() &&
205 defWrite.getVectorType() == read.getVectorType() &&
206 defWrite.getPermutationMap() == read.getPermutationMap() &&
207 ((!defWrite.getMask() && !read.getMask()) ||
212 vector::TransferWriteOp priorWrite) {
213 return priorWrite.getIndices() == write.getIndices() &&
214 priorWrite.getMask() == write.getMask() &&
215 priorWrite.getVectorType() == write.getVectorType() &&
216 priorWrite.getPermutationMap() == write.getPermutationMap();
220 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221 bool testDynamicValueUsingBounds) {
223 if (transferA.getVectorType() != transferB.getVectorType())
225 unsigned rankOffset = transferA.getLeadingShapedRank();
226 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227 Value indexA = transferA.getIndices()[i];
228 Value indexB = transferB.getIndices()[i];
232 if (i < rankOffset) {
235 if (cstIndexA.has_value() && cstIndexB.has_value()) {
236 if (*cstIndexA != *cstIndexB)
240 if (testDynamicValueUsingBounds) {
243 FailureOr<uint64_t> delta =
245 if (succeeded(delta) && *delta != 0)
248 FailureOr<bool> testEqual =
250 if (succeeded(testEqual) && !testEqual.value())
256 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257 if (cstIndexA.has_value() && cstIndexB.has_value()) {
258 int64_t distance =
std::abs(*cstIndexA - *cstIndexB);
259 if (distance >= vectorDim)
263 if (testDynamicValueUsingBounds) {
266 FailureOr<int64_t> delta =
268 if (succeeded(delta) &&
std::abs(*delta) >= vectorDim)
271 FailureOr<int64_t> computeDelta =
273 if (succeeded(computeDelta)) {
274 if (
std::abs(computeDelta.value()) >= vectorDim)
284 VectorTransferOpInterface transferB,
285 bool testDynamicValueUsingBounds) {
286 if (transferA.getSource() != transferB.getSource())
289 testDynamicValueUsingBounds);
299 for (
auto [posInDim, dimSize, offsetInDim] :
300 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
302 if (posInDim < dimSize + offsetInDim)
306 posInDim = offsetInDim;
316 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
318 assert(constOp &&
"Unexpected non-constant index");
319 return constOp.value();
329 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
330 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
331 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
341 llvm::transform(foldResults, std::back_inserter(values),
343 if (
auto attr = foldResult.dyn_cast<
Attribute>())
346 loc, cast<IntegerAttr>(attr).getInt())
349 return cast<Value>(foldResult);
360 auto lhs = mul.getLhs();
361 auto rhs = mul.getRhs();
362 if (lhs.getDefiningOp<vector::VectorScaleOp>())
364 if (rhs.getDefiningOp<vector::VectorScaleOp>())
412 void VectorDialect::initialize() {
414 #define GET_ATTRDEF_LIST
415 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
420 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
423 addInterfaces<VectorInlinerInterface>();
425 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
426 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
428 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
430 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
431 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
439 return arith::ConstantOp::materialize(builder, value, type, loc);
455 void vector::MultiDimReductionOp::build(
OpBuilder &builder,
458 CombiningKind kind) {
462 reductionDims.push_back(en.index());
463 build(builder, result, kind, source, acc, reductionDims);
466 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
468 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
473 std::optional<SmallVector<int64_t, 4>>
474 MultiDimReductionOp::getShapeForUnroll() {
475 return llvm::to_vector<4>(getSourceVectorType().
getShape());
481 Type inferredReturnType;
482 auto sourceScalableDims = getSourceVectorType().getScalableDims();
483 for (
auto [dimIdx, dimSize] :
485 if (!llvm::any_of(getReductionDims(),
486 [dimIdx = dimIdx](int64_t reductionDimIdx) {
487 return reductionDimIdx ==
static_cast<int64_t
>(dimIdx);
489 targetShape.push_back(dimSize);
490 scalableDims.push_back(sourceScalableDims[dimIdx]);
493 if (targetShape.empty())
494 inferredReturnType = getSourceVectorType().getElementType();
497 targetShape, getSourceVectorType().
getElementType(), scalableDims);
498 if (
getType() != inferredReturnType)
499 return emitOpError() <<
"destination type " <<
getType()
500 <<
" is incompatible with source type "
501 << getSourceVectorType();
507 Type MultiDimReductionOp::getExpectedMaskType() {
508 auto vecType = getSourceVectorType();
511 vecType.getScalableDims());
520 struct ElideUnitDimsInMultiDimReduction
524 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
527 for (
const auto &dim :
enumerate(shape)) {
528 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
536 if (reductionOp.isMasked()) {
538 rootOp = reductionOp.getMaskingOp();
539 mask = reductionOp.getMaskingOp().getMask();
541 rootOp = reductionOp;
544 Location loc = reductionOp.getLoc();
545 Value acc = reductionOp.getAcc();
547 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
549 VectorType newMaskType =
551 dstVecType.getScalableDims());
552 mask = rewriter.
create<vector::ShapeCastOp>(loc, newMaskType, mask);
554 cast = rewriter.
create<vector::ShapeCastOp>(
555 loc, reductionOp.getDestType(), reductionOp.getSource());
561 mask = rewriter.
create<vector::ExtractOp>(loc, mask, zeroIdx);
562 cast = rewriter.
create<vector::ExtractOp>(loc, reductionOp.getSource(),
568 cast,
nullptr, mask);
575 void MultiDimReductionOp::getCanonicalizationPatterns(
577 results.
add<ElideUnitDimsInMultiDimReduction>(context);
585 CombiningKind kind,
Value vector,
586 arith::FastMathFlags fastMathFlags) {
587 build(builder, result, kind, vector,
Value(), fastMathFlags);
592 arith::FastMathFlags fastMathFlags) {
593 build(builder, result,
594 llvm::cast<VectorType>(vector.
getType()).getElementType(), kind, vector,
600 int64_t rank = getSourceVectorType().getRank();
602 return emitOpError(
"unsupported reduction rank: ") << rank;
605 Type eltType = getDest().getType();
607 return emitOpError(
"unsupported reduction type '")
608 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
617 Type ReductionOp::getExpectedMaskType() {
618 auto vecType = getSourceVectorType();
621 vecType.getScalableDims());
628 case arith::AtomicRMWKind::addf:
629 case arith::AtomicRMWKind::addi:
630 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
631 CombiningKind::ADD, vector);
632 case arith::AtomicRMWKind::mulf:
633 case arith::AtomicRMWKind::muli:
634 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
635 CombiningKind::MUL, vector);
636 case arith::AtomicRMWKind::minimumf:
637 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
638 CombiningKind::MINIMUMF, vector);
639 case arith::AtomicRMWKind::mins:
640 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
641 CombiningKind::MINSI, vector);
642 case arith::AtomicRMWKind::minu:
643 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
645 case arith::AtomicRMWKind::maximumf:
646 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
647 CombiningKind::MAXIMUMF, vector);
648 case arith::AtomicRMWKind::maxs:
649 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
650 CombiningKind::MAXSI, vector);
651 case arith::AtomicRMWKind::maxu:
652 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
653 CombiningKind::MAXUI, vector);
654 case arith::AtomicRMWKind::andi:
655 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
656 CombiningKind::AND, vector);
657 case arith::AtomicRMWKind::ori:
658 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
659 CombiningKind::OR, vector);
668 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
669 return llvm::to_vector<4>(getSourceVectorType().
getShape());
676 LogicalResult matchAndRewrite(ReductionOp reductionOp,
681 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
684 if (maskableOp.isMasked()) {
686 rootOp = maskableOp.getMaskingOp();
687 mask = maskableOp.getMaskingOp().getMask();
689 rootOp = reductionOp;
692 auto vectorType = reductionOp.getSourceVectorType();
693 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
696 Location loc = reductionOp.getLoc();
698 if (vectorType.getRank() == 0) {
700 mask = rewriter.
create<ExtractElementOp>(loc, mask);
701 result = rewriter.
create<ExtractElementOp>(loc, reductionOp.getVector());
704 mask = rewriter.
create<ExtractOp>(loc, mask, 0);
705 result = rewriter.
create<ExtractOp>(loc, reductionOp.getVector(), 0);
708 if (
Value acc = reductionOp.getAcc())
711 reductionOp.getFastmathAttr(), mask);
721 results.
add<ElideSingleElementReduction>(context);
735 getIndexingMapsAttrName(result.
name),
739 getIteratorTypesAttrName(result.
name),
742 return IteratorTypeAttr::get(builder.getContext(), t);
748 ArrayAttr indexingMaps,
749 ArrayAttr iteratorTypes) {
750 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
751 ContractionOp::getDefaultKind());
756 ArrayAttr indexingMaps,
757 ArrayAttr iteratorTypes, CombiningKind kind) {
774 DictionaryAttr dictAttr;
789 dictAttr.getValue().end());
795 ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
800 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
801 auto maybeIteratorType = symbolizeIteratorType(s);
802 if (!maybeIteratorType.has_value())
803 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
805 iteratorTypeAttrs.push_back(
813 getKindAttrName(result.
name),
815 ContractionOp::getDefaultKind()));
817 if (masksInfo.empty())
819 if (masksInfo.size() != 2)
821 "expected zero or exactly 2 vector mask operands");
822 auto lhsType = llvm::cast<VectorType>(types[0]);
823 auto rhsType = llvm::cast<VectorType>(types[1]);
825 std::array<VectorType, 2> maskTypes = {
835 auto attrNames = getTraitAttrNames();
837 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
839 for (
auto attr : (*this)->getAttrs()) {
840 if (attr.getName() == getIteratorTypesAttrName()) {
842 llvm::cast<ArrayAttr>(attr.getValue())
843 .getAsValueRange<IteratorTypeAttr, IteratorType>();
849 llvm::map_range(iteratorTypes, [&](IteratorType t) ->
Attribute {
853 attrs.emplace_back(getIteratorTypesAttrName(),
855 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
856 attrs.push_back(attr);
860 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
861 p << getRhs() <<
", " << getAcc();
864 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
869 const std::vector<std::pair<int64_t, int64_t>> &map) {
870 for (
auto &dimPair : map) {
871 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
872 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
873 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
880 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
882 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
883 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
886 for (
auto &dimPair : contractingDimMap) {
887 lhsContractingDimSet.insert(dimPair.first);
888 rhsContractingDimSet.insert(dimPair.second);
891 for (
auto &dimPair : batchDimMap)
892 rhsBatchDimSet.insert(dimPair.second);
896 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
897 if (lhsContractingDimSet.count(i) > 0)
899 expectedResultDims.push_back(lhsType.getDimSize(i));
903 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
904 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
906 expectedResultDims.push_back(rhsType.getDimSize(i));
910 if (expectedResultDims.empty()) {
912 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
913 return op.emitOpError(
"invalid accumulator/result vector shape");
916 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
917 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
918 if (!resVectorType || !accVectorType)
919 return op.emitOpError(
"invalid accumulator/result vector shape");
925 AffineMap lhsMap = op.getIndexingMapsArray()[0];
926 AffineMap rhsMap = op.getIndexingMapsArray()[1];
928 return op.emitOpError(
929 "expected all dimensions to be either a LHS or a RHS dimension");
932 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
933 VectorType v = pair.first;
934 auto map = pair.second;
935 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
936 unsigned pos = map.getDimPosition(idx);
941 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
942 return op.emitOpError(
"expected all dimensions to get an extent as "
943 "either a LHS or a RHS dimension");
945 AffineMap resMap = op.getIndexingMapsArray()[2];
951 llvm::IsaPred<AffineConstantExpr>) &&
952 "expected constant extent along all dimensions.");
954 auto expectedShape = llvm::to_vector<4>(
956 return cast<AffineConstantExpr>(e).getValue();
960 resVectorType.getScalableDims());
961 if (resVectorType != expected || accVectorType != expected)
962 return op.emitOpError(
963 "invalid accumulator/result vector shape, expected: ")
970 VectorType lhsType = getLhsType();
971 VectorType rhsType = getRhsType();
972 Type accType = getAccType();
973 Type resType = getResultType();
975 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
976 if (!lhsType.getElementType().isSignlessInteger())
977 return emitOpError(
"only supports signless integer types");
981 if (getIndexingMapsArray().size() != 3)
982 return emitOpError(
"expected an indexing map for each vector operand");
987 unsigned numIterators = getIteratorTypes().getValue().size();
989 auto index = it.index();
990 auto map = it.value();
991 if (map.getNumSymbols() != 0)
992 return emitOpError(
"expected indexing map ")
993 << index <<
" to have no symbols";
994 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).
getType());
995 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
998 if (map.getNumDims() != numIterators)
999 return emitOpError(
"expected indexing map ")
1000 << index <<
" to have " << numIterators <<
" number of inputs";
1001 if (map.getNumResults() != rank)
1002 return emitOpError(
"expected indexing map ")
1003 << index <<
" to have " << rank <<
" number of outputs";
1004 if (!map.isProjectedPermutation())
1005 return emitOpError(
"expected indexing map ")
1006 << index <<
" to be a projected permutation of its inputs";
1009 auto contractingDimMap = getContractingDimMap();
1010 auto batchDimMap = getBatchDimMap();
1013 if (contractingDimMap.empty())
1014 return emitOpError(
"expected at least one contracting dimension pair");
1017 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1018 return emitOpError(
"invalid contracting dimension map");
1022 return emitOpError(
"invalid batch dimension map");
1026 contractingDimMap, batchDimMap)))
1030 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1031 auto elementType = vectorType ? vectorType.getElementType() : resType;
1033 return emitOpError(
"unsupported contraction type");
1042 Type ContractionOp::getExpectedMaskType() {
1043 auto indexingMaps = this->getIndexingMapsArray();
1046 VectorType lhsType = this->getLhsType();
1047 VectorType rhsType = this->getRhsType();
1049 unsigned numVecDims = lhsIdxMap.
getNumDims();
1058 lhsType.getScalableDims()[dimIdx];
1063 rhsType.getScalableDims()[dimIdx];
1066 assert(!ShapedType::isDynamicShape(maskShape) &&
1067 "Mask shape couldn't be computed");
1071 maskShapeScalableDims);
1076 getIteratorTypesAttrName(), getKindAttrName()};
1086 static std::vector<std::pair<int64_t, int64_t>>
1088 IteratorType targetIteratorType,
MLIRContext *context) {
1089 std::vector<std::pair<int64_t, int64_t>> dimMap;
1091 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1092 if (iteratorType != targetIteratorType)
1098 if (lhsDim >= 0 && rhsDim >= 0)
1099 dimMap.emplace_back(lhsDim, rhsDim);
1104 void ContractionOp::getIterationBounds(
1106 auto lhsShape = getLhsType().getShape();
1107 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1113 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1114 if (iteratorType == IteratorType::reduction) {
1116 int64_t lhsDimIndex =
getResultIndex(indexingMaps[0], targetExpr);
1117 assert(lhsDimIndex >= 0);
1118 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1122 int64_t resDimIndex =
getResultIndex(indexingMaps[2], targetExpr);
1123 assert(resDimIndex >= 0);
1124 assert(resVectorType !=
nullptr);
1125 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1129 void ContractionOp::getIterationIndexMap(
1131 unsigned numMaps = getIndexingMapsArray().size();
1132 iterationIndexMap.resize(numMaps);
1134 auto index = it.index();
1135 auto map = it.value();
1136 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1137 auto dim = cast<AffineDimExpr>(map.getResult(i));
1138 iterationIndexMap[index][dim.getPosition()] = i;
1143 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1145 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1149 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1151 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1155 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1157 getIterationBounds(shape);
1179 template <
typename AddOpType>
1185 auto canonicalize = [&](
Value maybeContraction,
1186 Value otherOperand) -> vector::ContractionOp {
1187 vector::ContractionOp contractionOp =
1188 dyn_cast_or_null<vector::ContractionOp>(
1191 return vector::ContractionOp();
1192 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1193 contractionOp.getAcc().getDefiningOp())) {
1194 if (maybeZero.getValue() ==
1195 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1197 bvm.
map(contractionOp.getAcc(), otherOperand);
1198 auto newContraction =
1199 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1200 rewriter.
replaceOp(addOp, newContraction.getResult());
1201 return newContraction;
1204 return vector::ContractionOp();
1207 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1208 vector::ContractionOp
contract = canonicalize(a, b);
1210 return contract ? success() : failure();
1226 setResultRanges(getResult(), argRanges.front());
1232 result.
addTypes(llvm::cast<VectorType>(source.
getType()).getElementType());
1236 VectorType vectorType = getSourceVectorType();
1237 if (vectorType.getRank() == 0) {
1239 return emitOpError(
"expected position to be empty with 0-D vector");
1242 if (vectorType.getRank() != 1)
1243 return emitOpError(
"unexpected >1 vector rank");
1245 return emitOpError(
"expected position for 1-D vector");
1249 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1251 if (!adaptor.getPosition())
1255 if (
auto splat = getVector().getDefiningOp<vector::SplatOp>())
1256 return splat.getInput();
1259 if (
auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1263 auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1264 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1268 auto srcElements = src.getValues<
Attribute>();
1270 uint64_t posIdx = pos.getInt();
1271 if (posIdx >= srcElements.size())
1274 return srcElements[posIdx];
1283 setResultRanges(getResult(), argRanges.front());
1287 Value source, int64_t position) {
1307 build(builder, result, source, dynamicPos,
1312 ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1313 ExtractOp::Adaptor adaptor,
1315 auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1316 if (
static_cast<int64_t
>(adaptor.getStaticPosition().size()) ==
1317 vectorType.getRank()) {
1318 inferredReturnTypes.push_back(vectorType.getElementType());
1320 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1321 vectorType.getRank());
1323 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1324 vectorType.getScalableDims().drop_front(n)));
1332 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1333 return vectorType && vectorType.getShape().equals({1}) &&
1334 vectorType.getElementType() == r.front();
1336 if (l.size() == 1 && r.size() == 1 &&
1337 (isCompatible(l, r) || isCompatible(r, l)))
1344 auto dynamicMarkersCount =
1345 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1346 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1348 "mismatch between dynamic and static positions (kDynamic marker but no "
1349 "corresponding dynamic position) -- this can only happen due to an "
1350 "incorrect fold/rewrite");
1351 auto position = getMixedPosition();
1352 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1354 "expected position attribute of rank no greater than vector rank");
1356 if (
auto attr = dyn_cast<Attribute>(pos)) {
1357 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1358 if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1359 return emitOpError(
"expected position attribute #")
1361 <<
" to be a non-negative integer smaller than the "
1362 "corresponding vector dimension";
1369 template <
typename IntType>
1371 return llvm::to_vector<4>(llvm::map_range(
1372 arrayAttr.getAsRange<IntegerAttr>(),
1373 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1379 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1383 if (extractOp.hasDynamicPosition())
1387 ExtractOp currentOp = extractOp;
1389 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1390 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1393 if (currentOp.hasDynamicPosition())
1396 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1398 extractOp.setOperand(0, currentOp.getVector());
1401 std::reverse(globalPosition.begin(), globalPosition.end());
1402 extractOp.setStaticPosition(globalPosition);
1414 class ExtractFromInsertTransposeChainState {
1416 ExtractFromInsertTransposeChainState(ExtractOp e);
1425 template <
typename ContainerA,
typename ContainerB>
1426 bool isContainedWithin(
const ContainerA &a,
const ContainerB &b) {
1427 return a.size() <= b.size() &&
1428 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1435 template <
typename ContainerA,
typename ContainerB>
1436 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &b) {
1437 for (
auto [elemA, elemB] : llvm::zip(a, b)) {
1438 if (elemA < 0 || elemB < 0)
1453 void updateStateForNextIteration(
Value v) {
1460 LogicalResult handleTransposeOp();
1463 LogicalResult handleInsertOpWithMatchingPos(
Value &res);
1478 LogicalResult handleInsertOpWithPrefixPos(
Value &res);
1483 Value tryToFoldExtractOpInPlace(
Value source);
1485 ExtractOp extractOp;
1487 int64_t extractedRank;
1489 InsertOp nextInsertOp;
1490 TransposeOp nextTransposeOp;
1505 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1507 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1508 extractedRank(extractOp.getNumIndices()) {
1509 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1510 sentinels.reserve(vectorRank - extractedRank);
1511 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1512 sentinels.push_back(-(i + 1));
1514 extractOp.getStaticPosition().end());
1520 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1522 if (extractOp.hasDynamicPosition())
1525 if (!nextTransposeOp)
1528 nextTransposeOp.getPermutation(), extractOp.getContext()));
1535 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1538 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1545 res = nextInsertOp.getSource();
1547 return success(canFold());
1554 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(
Value &res) {
1556 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1569 res = nextInsertOp.getSource();
1577 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1580 if (extractOp.hasDynamicPosition())
1584 bool nothingToFold = (source == extractOp.getVector());
1585 if (nothingToFold || !canFold())
1590 extractOp.setStaticPosition(
1592 extractOp.getVectorMutable().assign(source);
1593 return extractOp.getResult();
1597 Value ExtractFromInsertTransposeChainState::fold() {
1599 if (extractOp.hasDynamicPosition())
1602 Value valueToExtractFrom = extractOp.getVector();
1603 updateStateForNextIteration(valueToExtractFrom);
1604 while (nextInsertOp || nextTransposeOp) {
1607 if (succeeded(handleTransposeOp())) {
1608 valueToExtractFrom = nextTransposeOp.getVector();
1609 updateStateForNextIteration(valueToExtractFrom);
1615 if (succeeded(handleInsertOpWithMatchingPos(result)))
1620 if (succeeded(handleInsertOpWithPrefixPos(result)))
1621 return tryToFoldExtractOpInPlace(result);
1631 valueToExtractFrom = nextInsertOp.getDest();
1632 updateStateForNextIteration(valueToExtractFrom);
1635 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1640 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1641 auto vecType = dyn_cast<VectorType>(type);
1642 return vecType && vecType.getRank() == 0;
1652 if (extractOp.hasDynamicPosition())
1655 Operation *defOp = extractOp.getVector().getDefiningOp();
1656 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1660 if (extractOp.getType() == source.
getType())
1662 auto getRank = [](
Type type) {
1663 return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1668 unsigned broadcastSrcRank = getRank(source.
getType());
1669 if (broadcastSrcRank == 0 && source.
getType() == extractOp.getType())
1672 unsigned extractResultRank = getRank(extractOp.getType());
1673 if (extractResultRank >= broadcastSrcRank)
1676 auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1677 auto broadcastVecType = llvm::dyn_cast<VectorType>(source.
getType());
1678 if (extractVecType && broadcastVecType &&
1679 extractVecType.getShape() !=
1680 broadcastVecType.getShape().take_back(extractResultRank))
1683 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1684 int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1690 broadcastOp.computeBroadcastedUnitDims();
1692 int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1693 for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1694 if (broadcastedUnitDims.contains(i))
1698 int64_t rankDiff = broadcastSrcRank - extractResultRank;
1699 extractPos.erase(extractPos.begin(),
1700 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1703 extractOp.setOperand(0, source);
1704 extractOp.setStaticPosition(extractPos);
1705 return extractOp.getResult();
1721 if (extractOp.hasDynamicPosition())
1724 auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1729 if (shuffleOp.getResultVectorType().getRank() != 1)
1732 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1733 auto shuffleMask = shuffleOp.getMask();
1734 int64_t extractIdx = extractOp.getStaticPosition()[0];
1735 int64_t shuffleIdx = shuffleMask[extractIdx];
1738 if (shuffleIdx < inputVecSize) {
1739 extractOp.setOperand(0, shuffleOp.getV1());
1740 extractOp.setStaticPosition({shuffleIdx});
1742 extractOp.setOperand(0, shuffleOp.getV2());
1743 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1746 return extractOp.getResult();
1752 if (extractOp.hasDynamicPosition())
1755 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1760 auto getDimReverse = [](VectorType type, int64_t n) {
1761 return type.getShape().take_back(n + 1).front();
1763 int64_t destinationRank =
1764 llvm::isa<VectorType>(extractOp.getType())
1765 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1767 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1769 if (destinationRank > 0) {
1770 auto destinationType =
1771 llvm::cast<VectorType>(extractOp.getResult().getType());
1772 for (int64_t i = 0; i < destinationRank; i++) {
1776 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1777 getDimReverse(destinationType, i))
1784 std::reverse(extractedPos.begin(), extractedPos.end());
1787 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1788 strides.push_back(stride);
1790 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1793 int64_t position =
linearize(extractedPos, strides);
1797 int64_t numDimension =
1798 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1800 for (int64_t i = 0; i < numDimension; i++) {
1801 newStrides.push_back(stride);
1803 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1805 std::reverse(newStrides.begin(), newStrides.end());
1809 extractOp.setStaticPosition(newPosition);
1810 extractOp.setOperand(0, shapeCastOp.getSource());
1811 return extractOp.getResult();
1817 if (extractOp.hasDynamicPosition())
1820 auto extractStridedSliceOp =
1821 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1822 if (!extractStridedSliceOp)
1831 if (extractStridedSliceOp.hasNonUnitStrides())
1836 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1837 while (!sliceOffsets.empty()) {
1838 size_t lastOffset = sliceOffsets.size() - 1;
1839 if (sliceOffsets.back() != 0 ||
1840 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1841 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1843 sliceOffsets.pop_back();
1845 unsigned destinationRank = 0;
1846 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1847 destinationRank = vecType.getRank();
1850 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1851 sliceOffsets.size())
1855 assert(extractedPos.size() >= sliceOffsets.size());
1856 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1857 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1858 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1862 extractOp.setStaticPosition(extractedPos);
1863 return extractOp.getResult();
1869 if (extractOp.hasDynamicPosition())
1872 int64_t destinationRank =
1873 llvm::isa<VectorType>(extractOp.getType())
1874 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1876 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1886 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1887 insertOp.getSourceVectorType().getRank();
1888 if (destinationRank > insertOp.getSourceVectorType().getRank())
1890 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1893 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1894 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1897 bool disjoint =
false;
1899 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1900 int64_t start = insertOffsets[dim];
1902 (dim < insertRankDiff)
1904 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1905 int64_t end = start + size;
1906 int64_t offset = extractOffsets[dim];
1908 if (start <= offset && offset < end) {
1909 if (dim >= insertRankDiff)
1910 offsetDiffs.push_back(offset - start);
1920 int64_t srcRankDiff =
1921 insertOp.getSourceVectorType().getRank() - destinationRank;
1922 for (int64_t i = 0; i < destinationRank; i++) {
1923 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1924 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1928 extractOp.getVectorMutable().assign(insertOp.getSource());
1931 extractOp.setStaticPosition(offsetDiffs);
1932 return extractOp.getResult();
1936 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1949 if (extractOp.hasDynamicPosition())
1953 auto fromElementsOp = extractOp.getVector().
getDefiningOp<FromElementsOp>();
1954 if (!fromElementsOp)
1958 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1959 if (vecType.isScalable())
1963 int64_t rank = vecType.getRank();
1965 if (extractOp.getType() != vecType.getElementType())
1967 assert(
static_cast<int64_t
>(indices.size()) == rank &&
1968 "unexpected number of indices");
1973 for (
int i = rank - 1; i >= 0; --i) {
1974 flatIndex += indices[i] * stride;
1975 stride *= vecType.getDimSize(i);
1977 return fromElementsOp.getElements()[flatIndex];
1984 if (getNumIndices() == 0 && getVector().
getType() == getResult().
getType())
1988 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2014 Operation *defOp = extractOp.getVector().getDefiningOp();
2015 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2019 if (extractOp.getType() == source.
getType())
2021 auto getRank = [](
Type type) {
2022 return llvm::isa<VectorType>(type)
2023 ? llvm::cast<VectorType>(type).getRank()
2026 unsigned broadcastSrcRank = getRank(source.
getType());
2027 unsigned extractResultRank = getRank(extractOp.getType());
2031 if (extractResultRank < broadcastSrcRank)
2035 if (extractResultRank == 0) {
2036 assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.
getType()));
2041 extractOp, extractOp.getType(), source);
2047 class ExtractOpSplatConstantFolder final :
public OpRewritePattern<ExtractOp> {
2055 Value sourceVector = extractOp.getVector();
2059 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2062 TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2063 if (
auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2071 class ExtractOpNonSplatConstantFolder final
2079 if (extractOp.hasDynamicPosition())
2084 Value sourceVector = extractOp.getVector();
2089 auto vecTy = llvm::cast<VectorType>(sourceVector.
getType());
2090 if (vecTy.isScalable())
2094 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2095 if (!dense || dense.isSplat())
2101 copy(extractOp.getStaticPosition(), completePositions.begin());
2102 int64_t elemBeginPosition =
2104 auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2107 if (
auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2109 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2112 newAttr = *denseValuesBegin;
2128 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2132 VectorType extractedMaskType =
2133 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2135 if (!extractedMaskType)
2138 auto maskOperands = createMaskOp.getOperands();
2140 VectorType maskType = createMaskOp.getVectorType();
2142 bool containsUnknownDims =
false;
2145 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2147 int64_t pos = extractOpPos[dimIdx];
2148 Value operand = maskOperands[dimIdx];
2149 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2152 containsUnknownDims =
true;
2156 int64_t createMaskBound =
2157 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2159 if (pos != ShapedType::kDynamic) {
2162 allFalse |= pos >= createMaskBound;
2163 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2167 containsUnknownDims =
true;
2174 }
else if (!containsUnknownDims) {
2176 extractOp, extractedMaskType,
2177 maskOperands.drop_front(extractOpPos.size()));
2187 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2189 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2193 VectorType sourceType = castOp.getSourceVectorType();
2194 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2198 if (sourceType.getNumElements() != targetType.getNumElements())
2202 castOp.getSource());
2212 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2215 if (extractOp.hasDynamicPosition())
2219 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2224 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2225 if (!fromElementsOp)
2227 VectorType inputType = fromElementsOp.getType();
2230 if (resultType.isScalable() || inputType.isScalable())
2236 llvm::to_vector(extractOp.getStaticPosition());
2237 firstElementPos.append(resultType.getRank(), 0);
2240 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2241 flatIndex += firstElementPos[i] * stride;
2242 stride *= inputType.getDimSize(i);
2247 extractOp, resultType,
2248 fromElementsOp.getElements().slice(flatIndex,
2249 resultType.getNumElements()));
2256 results.
add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2257 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2258 results.
add(foldExtractFromShapeCastToShapeCast);
2259 results.
add(foldExtractFromFromElements);
2264 for (
auto attr : arrayAttr)
2265 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2272 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2287 if (!llvm::all_equal(fromElementsOp.getElements()))
2290 fromElementsOp.getElements().front());
2305 setResultRanges(getResult(), argRanges.front());
2313 int64_t rankDiff = dstShape.size() - srcShape.size();
2314 int64_t dstDim = rankDiff;
2316 for (
auto [s1, s2] :
2317 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2319 assert(s1 == 1 &&
"expected dim-1 broadcasting");
2329 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2348 Value BroadcastOp::createOrFoldBroadcastOp(
2351 assert(!dstShape.empty() &&
"unexpected empty dst shape");
2355 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
2356 if (broadcastedDims.contains(i))
2358 checkShape.push_back(dstShape[i]);
2360 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2361 "ill-formed broadcastedDims contains values not confined to "
2366 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
2370 if (!srcVectorType) {
2371 assert(checkShape.empty() &&
2372 "ill-formed createOrFoldBroadcastOp arguments");
2373 return b.
createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2376 assert(srcVectorType.getShape().equals(checkShape) &&
2377 "ill-formed createOrFoldBroadcastOp arguments");
2388 broadcastShape.reserve(dstShape.size());
2404 int64_t nextSrcShapeDim = broadcastedDims.size();
2405 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2406 if (broadcastedDims.contains(i)) {
2411 broadcastShape.push_back(dstShape[i]);
2412 permutation[i] = broadcastShape.size() - 1;
2418 permutation[i] = nextSrcShapeDim++;
2422 llvm::append_range(broadcastShape, srcVectorType.getShape());
2427 "unexpected dim-1 broadcast");
2429 VectorType broadcastType =
VectorType::get(broadcastShape, elementType);
2431 vector::BroadcastableToResult::Success &&
2432 "must be broadcastable");
2436 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2437 if (permutation[i] != i)
2438 return b.
createOrFold<vector::TransposeOp>(loc, res, permutation);
2444 Type srcType, VectorType dstVectorType,
2445 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2449 return BroadcastableToResult::Success;
2451 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2453 return BroadcastableToResult::SourceTypeNotAVector;
2455 int64_t srcRank = srcVectorType.getRank();
2456 int64_t dstRank = dstVectorType.getRank();
2457 if (srcRank > dstRank)
2458 return BroadcastableToResult::SourceRankHigher;
2461 int64_t lead = dstRank - srcRank;
2462 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2465 bool foundMismatchingDims =
false;
2468 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2469 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2470 if (srcDim != 1 && srcDim != dstDim)
2471 foundMismatchingDims =
true;
2474 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2475 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2476 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2479 (srcDimScalableFlag != dstDimScalableFlag &&
2480 (srcDim != 1 || srcDimScalableFlag)))
2481 foundMismatchingDims =
true;
2483 if (foundMismatchingDims) {
2484 if (mismatchingDims !=
nullptr) {
2485 mismatchingDims->first.dim = srcDim;
2486 mismatchingDims->first.isScalable = srcDimScalableFlag;
2488 mismatchingDims->second.dim = dstDim;
2489 mismatchingDims->second.isScalable = dstDimScalableFlag;
2491 return BroadcastableToResult::DimensionMismatch;
2495 return BroadcastableToResult::Success;
2499 std::pair<VectorDim, VectorDim> mismatchingDims;
2501 getSourceType(), getResultVectorType(), &mismatchingDims);
2502 if (res == BroadcastableToResult::Success)
2504 if (res == BroadcastableToResult::SourceRankHigher)
2505 return emitOpError(
"source rank higher than destination rank");
2506 if (res == BroadcastableToResult::DimensionMismatch) {
2507 return emitOpError(
"dimension mismatch (")
2508 << (mismatchingDims.first.isScalable ?
"[" :
"")
2509 << mismatchingDims.first.dim
2510 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
2511 << (mismatchingDims.second.isScalable ?
"[" :
"")
2512 << mismatchingDims.second.dim
2513 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
2515 if (res == BroadcastableToResult::SourceTypeNotAVector)
2516 return emitOpError(
"source type is not a vector");
2517 llvm_unreachable(
"unexpected vector.broadcast op error");
2521 if (getSourceType() == getResultVectorType())
2523 if (!adaptor.getSource())
2525 auto vectorType = getResultVectorType();
2526 if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2528 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2541 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2545 broadcastOp.getResultVectorType(),
2546 srcBroadcast.getSource());
2556 results.
add<BroadcastFolder>(context);
2564 VectorType resultType = getResultVectorType();
2565 VectorType v1Type = getV1VectorType();
2566 VectorType v2Type = getV2VectorType();
2568 int64_t resRank = resultType.getRank();
2569 int64_t v1Rank = v1Type.getRank();
2570 int64_t v2Rank = v2Type.getRank();
2571 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2572 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2573 if (!wellFormed0DCase && !wellFormedNDCase)
2574 return emitOpError(
"rank mismatch");
2577 for (int64_t r = 1; r < v1Rank; ++r) {
2578 int64_t resDim = resultType.getDimSize(r);
2579 int64_t v1Dim = v1Type.getDimSize(r);
2580 int64_t v2Dim = v2Type.getDimSize(r);
2581 if (resDim != v1Dim || v1Dim != v2Dim)
2582 return emitOpError(
"dimension mismatch");
2586 int64_t maskLength = mask.size();
2587 if (maskLength <= 0)
2588 return emitOpError(
"invalid mask length");
2589 if (maskLength != resultType.getDimSize(0))
2590 return emitOpError(
"mask length mismatch");
2592 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2593 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2595 if (maskPos < 0 || maskPos >= indexSize)
2596 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
2602 ShuffleOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
2603 ShuffleOp::Adaptor adaptor,
2605 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2606 auto v1Rank = v1Type.getRank();
2610 shape.reserve(v1Rank);
2611 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2614 llvm::append_range(shape, v1Type.getShape().drop_front());
2615 inferredReturnTypes.push_back(
2620 template <
typename T>
2623 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2624 return value == expected++;
2628 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2629 VectorType v1Type = getV1VectorType();
2632 if (v1Type.getRank() == 0)
2636 if (!v1Type.isScalable() &&
2640 if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2642 getV2VectorType().getDimSize(0)))
2645 Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2650 llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).
getType());
2653 if (lhsType.getRank() != 1)
2655 int64_t lhsSize = lhsType.getDimSize(0);
2658 auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<
Attribute>();
2659 auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<
Attribute>();
2660 for (int64_t i : this->getMask()) {
2662 results.push_back(rhsElements[i - lhsSize]);
2664 results.push_back(lhsElements[i]);
2680 VectorType v1VectorType = shuffleOp.getV1VectorType();
2682 if (v1VectorType.getRank() > 0)
2684 if (mask.size() != 1)
2704 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2705 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2707 if (!v1Splat || !v2Splat)
2710 if (v1Splat.getInput() != v2Splat.getInput())
2726 VectorType resultType = op.getResultVectorType();
2727 if (resultType.isScalable())
2729 op,
"ShuffleOp can't represent a scalable interleave");
2731 if (resultType.getRank() != 1)
2733 op,
"ShuffleOp can't represent an n-D interleave");
2735 VectorType sourceType = op.getV1VectorType();
2736 if (sourceType != op.getV2VectorType() ||
2737 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2739 op,
"ShuffleOp types don't match an interleave");
2743 int64_t resultVectorSize = resultType.getNumElements();
2744 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2745 int64_t maskValueA = shuffleMask[i * 2];
2746 int64_t maskValueB = shuffleMask[(i * 2) + 1];
2747 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2749 "ShuffleOp mask not interleaving");
2761 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2771 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2776 build(builder, result, source, dest, {});
2780 auto dstVectorType = getDestVectorType();
2781 if (dstVectorType.getRank() == 0) {
2783 return emitOpError(
"expected position to be empty with 0-D vector");
2786 if (dstVectorType.getRank() != 1)
2787 return emitOpError(
"unexpected >1 vector rank");
2789 return emitOpError(
"expected position for 1-D vector");
2793 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2795 if (!adaptor.getPosition())
2798 auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2799 auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2800 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2801 if (!src || !dst || !pos)
2807 auto dstElements = dst.getValues<
Attribute>();
2811 uint64_t posIdx = pos.getInt();
2812 if (posIdx >= results.size())
2814 results[posIdx] = src;
2825 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2829 Value source,
Value dest, int64_t position) {
2842 posVals.reserve(position.size());
2843 llvm::transform(position, std::back_inserter(posVals),
2845 build(builder, result, source, dest, posVals);
2854 build(builder, result, source, dest, dynamicPos,
2860 auto destVectorType = getDestVectorType();
2861 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
2863 "expected position attribute of rank no greater than dest vector rank");
2864 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2865 if (srcVectorType &&
2866 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2867 static_cast<unsigned>(destVectorType.getRank())))
2868 return emitOpError(
"expected position attribute rank + source rank to "
2869 "match dest vector rank");
2870 if (!srcVectorType &&
2871 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
2873 "expected position attribute rank to match the dest vector rank");
2875 if (
auto attr = pos.dyn_cast<
Attribute>()) {
2876 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2877 if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2878 return emitOpError(
"expected position attribute #")
2880 <<
" to be a non-negative integer smaller than the "
2882 "dest vector dimension";
2899 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2900 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2901 srcVecType.getNumElements())
2904 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2916 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2917 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2919 if (!srcSplat || !dstSplat)
2922 if (srcSplat.getInput() != dstSplat.getInput())
2937 static constexpr int64_t vectorSizeFoldThreshold = 256;
2942 if (op.hasDynamicPosition())
2951 auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
2955 VectorType destTy = destVector.getType();
2956 if (destTy.isScalable())
2960 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2961 !destVector.hasOneUse())
2964 Value sourceValue = op.getSource();
2972 copy(op.getStaticPosition(), completePositions.begin());
2973 int64_t insertBeginPosition =
2977 Type destEltType = destTy.getElementType();
2982 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
2983 for (
auto value : denseSource.getValues<
Attribute>())
2989 auto allValues = llvm::to_vector(denseDest.getValues<
Attribute>());
2990 copy(insertedValues, allValues.begin() + insertBeginPosition);
3001 if (
auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3002 if (intAttr.getType() != expectedType)
3013 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3014 InsertOpConstantFolder>(context);
3017 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3021 if (getNumIndices() == 0 && getSourceType() ==
getType())
3045 template <
typename OpType>
3047 ArrayAttr arrayAttr,
3049 StringRef attrName) {
3050 if (arrayAttr.size() > shape.size())
3051 return op.emitOpError(
"expected ")
3052 << attrName <<
" attribute of rank no greater than vector rank";
3059 template <
typename OpType>
3060 static LogicalResult
3062 int64_t
max, StringRef attrName,
3063 bool halfOpen =
true) {
3064 for (
auto attr : arrayAttr) {
3065 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3069 if (val < min || val >= upper)
3070 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
3071 <<
min <<
", " << upper <<
")";
3079 template <
typename OpType>
3080 static LogicalResult
3083 bool halfOpen =
true, int64_t
min = 0) {
3084 for (
auto [index, attrDimPair] :
3086 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3087 int64_t
max = std::get<1>(attrDimPair);
3090 if (val < min || val >=
max)
3091 return op.emitOpError(
"expected ")
3092 << attrName <<
" dimension " << index <<
" to be confined to ["
3093 <<
min <<
", " <<
max <<
")";
3103 template <
typename OpType>
3105 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3107 bool halfOpen =
true, int64_t
min = 1) {
3108 assert(arrayAttr1.size() <= shape.size());
3109 assert(arrayAttr2.size() <= shape.size());
3110 for (
auto [index, it] :
3112 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3113 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3114 int64_t
max = std::get<2>(it);
3117 if (val1 + val2 < 0 || val1 + val2 >=
max)
3118 return op.emitOpError(
"expected sum(")
3119 << attrName1 <<
", " << attrName2 <<
") dimension " << index
3120 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3127 auto attrs = llvm::map_range(values, [context](int64_t v) ->
Attribute {
3134 auto sourceVectorType = getSourceVectorType();
3135 auto destVectorType = getDestVectorType();
3136 auto offsets = getOffsetsAttr();
3137 auto strides = getStridesAttr();
3138 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3140 "expected offsets of same size as destination vector rank");
3141 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3142 return emitOpError(
"expected strides of same size as source vector rank");
3143 if (sourceVectorType.getRank() > destVectorType.getRank())
3145 "expected source rank to be no greater than destination rank");
3147 auto sourceShape = sourceVectorType.getShape();
3148 auto destShape = destVectorType.getShape();
3150 destShape.size() - sourceShape.size(), 0);
3151 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3152 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3153 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3162 offName,
"source vector shape",
3166 unsigned rankDiff = destShape.size() - sourceShape.size();
3167 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3168 if (sourceVectorType.getScalableDims()[idx] !=
3169 destVectorType.getScalableDims()[idx + rankDiff]) {
3170 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3173 if (sourceVectorType.getScalableDims()[idx]) {
3174 auto sourceSize = sourceShape[idx];
3175 auto destSize = destShape[idx + rankDiff];
3176 if (sourceSize != destSize) {
3177 return emitOpError(
"expected size at idx=")
3179 << (
" to match the corresponding base size from the input "
3181 << sourceSize << (
" vs ") << destSize << (
")");
3192 class FoldInsertStridedSliceSplat final
3197 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3200 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3202 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3204 if (!srcSplatOp || !destSplatOp)
3207 if (srcSplatOp.getInput() != destSplatOp.getInput())
3210 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3217 class FoldInsertStridedSliceOfExtract final
3222 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3224 auto extractStridedSliceOp =
3225 insertStridedSliceOp.getSource()
3226 .getDefiningOp<vector::ExtractStridedSliceOp>();
3228 if (!extractStridedSliceOp)
3231 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3235 if (extractStridedSliceOp.getStrides() !=
3236 insertStridedSliceOp.getStrides() ||
3237 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3240 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3247 class InsertStridedSliceConstantFolder final
3254 static constexpr int64_t vectorSizeFoldThreshold = 256;
3265 VectorType destTy = destVector.getType();
3266 if (destTy.isScalable())
3270 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3271 !destVector.hasOneUse())
3274 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3282 if (op.hasNonUnitStrides())
3285 VectorType sliceVecTy = sourceValue.getType();
3287 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3297 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3298 auto sliceValuesIt = denseSlice.value_begin<
Attribute>();
3299 auto newValues = llvm::to_vector(denseDest.getValues<
Attribute>());
3302 currDestPosition.begin() + rankDifference, currDestPosition.end());
3306 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
3307 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
3308 assert(sliceValuesIt != denseSlice.value_end<
Attribute>() &&
3309 "Invalid slice element");
3310 newValues[linearizedPosition] = *sliceValuesIt;
3323 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3325 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3326 InsertStridedSliceConstantFolder>(context);
3329 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3330 if (getSourceVectorType() == getDestVectorType())
3347 p <<
" " << getLhs() <<
", " << getRhs();
3349 p <<
", " << getAcc();
3352 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
3363 if (operandsInfo.size() < 2)
3365 "expected at least 2 operands");
3366 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3367 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3370 "expected vector type for operand #1");
3375 vRHS.getScalableDims()[0]};
3377 vLHS.getElementType(), scalableDimsRes);
3381 resType =
VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3387 OuterProductOp::getKindAttrName(result.
name),
3389 OuterProductOp::getDefaultKind()));
3395 (operandsInfo.size() > 2 &&
3401 Type tRHS = getOperandTypeRHS();
3402 VectorType vLHS = getOperandVectorTypeLHS(),
3403 vRHS = llvm::dyn_cast<VectorType>(tRHS),
3404 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3406 if (vLHS.getRank() != 1)
3407 return emitOpError(
"expected 1-d vector for operand #1");
3411 if (vRHS.getRank() != 1)
3412 return emitOpError(
"expected 1-d vector for operand #2");
3413 if (vRES.getRank() != 2)
3414 return emitOpError(
"expected 2-d vector result");
3415 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3416 return emitOpError(
"expected #1 operand dim to match result dim #1");
3417 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3418 return emitOpError(
"expected #2 operand dim to match result dim #2");
3419 if (vLHS.isScalable() && !vRHS.isScalable()) {
3423 "expected either both or only #2 operand dim to be scalable");
3427 if (vRES.getRank() != 1)
3428 return emitOpError(
"expected 1-d vector result");
3429 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3430 return emitOpError(
"expected #1 operand dim to match result dim #1");
3433 if (vACC && vACC != vRES)
3434 return emitOpError(
"expected operand #3 of same type as result type");
3438 return emitOpError(
"unsupported outerproduct type");
3447 Type OuterProductOp::getExpectedMaskType() {
3448 auto vecType = this->getResultVectorType();
3451 vecType.getScalableDims());
3463 ArrayAttr offsets, ArrayAttr sizes,
3464 ArrayAttr strides) {
3465 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3467 shape.reserve(vectorType.getRank());
3469 for (
unsigned e = offsets.size(); idx < e; ++idx)
3470 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3471 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3472 shape.push_back(vectorType.getShape()[idx]);
3475 vectorType.getScalableDims());
3488 offsetsAttr, sizesAttr, stridesAttr));
3489 result.
addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.
name),
3493 result.
addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.
name),
3498 auto type = getSourceVectorType();
3499 auto offsets = getOffsetsAttr();
3500 auto sizes = getSizesAttr();
3501 auto strides = getStridesAttr();
3502 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3504 "expected offsets, sizes and strides attributes of same size");
3506 auto shape = type.getShape();
3507 auto offName = getOffsetsAttrName();
3508 auto sizesName = getSizesAttrName();
3509 auto stridesName = getStridesAttrName();
3525 shape, offName, sizesName,
3530 offsets, sizes, strides);
3531 if (getResult().
getType() != resultType)
3532 return emitOpError(
"expected result type to be ") << resultType;
3534 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
3535 if (type.getScalableDims()[idx]) {
3536 auto inputDim = type.getShape()[idx];
3537 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3538 if (inputDim != inputSize)
3539 return emitOpError(
"expected size at idx=")
3541 << (
" to match the corresponding base size from the input "
3543 << inputSize << (
" vs ") << inputDim << (
")");
3553 static LogicalResult
3556 auto getElement = [](ArrayAttr array,
int idx) {
3557 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3559 ArrayAttr extractOffsets = op.getOffsets();
3561 ArrayAttr extractSizes = op.getSizes();
3562 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3564 if (op.getSourceVectorType().getRank() !=
3565 insertOp.getSourceVectorType().getRank())
3567 ArrayAttr insertOffsets = insertOp.getOffsets();
3568 ArrayAttr insertStrides = insertOp.getStrides();
3571 if (extractOffsets.size() > insertOffsets.size())
3573 bool patialoverlap =
false;
3574 bool disjoint =
false;
3576 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3577 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
3579 int64_t start = getElement(insertOffsets, dim);
3580 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3581 int64_t offset = getElement(extractOffsets, dim);
3582 int64_t size = getElement(extractSizes, dim);
3584 if (start <= offset && offset < end) {
3587 if (offset + size > end)
3588 patialoverlap =
true;
3589 offsetDiffs.push_back(offset - start);
3596 if (!disjoint && !patialoverlap) {
3597 op.setOperand(insertOp.getSource());
3606 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3616 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3617 if (getSourceVectorType() == getResult().
getType())
3632 class StridedSliceConstantMaskFolder final
3637 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3641 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3642 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3643 if (!constantMaskOp)
3646 if (extractStridedSliceOp.hasNonUnitStrides())
3659 sliceMaskDimSizes.reserve(maskDimSizes.size());
3660 for (
auto [maskDimSize, sliceOffset, sliceSize] :
3661 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3662 int64_t sliceMaskDimSize =
std::max(
3663 static_cast<int64_t
>(0),
3664 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3665 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3668 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3669 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3670 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3673 if (llvm::is_contained(sliceMaskDimSizes, 0))
3674 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3679 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3686 class StridedSliceSplatConstantFolder final
3691 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3695 Value sourceVector = extractStridedSliceOp.getVector();
3700 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3714 class StridedSliceNonSplatConstantFolder final
3719 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3723 Value sourceVector = extractStridedSliceOp.getVector();
3729 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3730 if (!dense || dense.isSplat())
3734 if (extractStridedSliceOp.hasNonUnitStrides())
3737 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.
getType());
3741 VectorType sliceVecTy = extractStridedSliceOp.getType();
3743 int64_t sliceRank = sliceVecTy.getRank();
3755 auto denseValuesBegin = dense.value_begin<
Attribute>();
3757 sliceValues.reserve(sliceVecTy.getNumElements());
3760 int64_t linearizedPosition =
linearize(currSlicePosition, sourceStrides);
3761 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3763 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3767 assert(
static_cast<int64_t
>(sliceValues.size()) ==
3768 sliceVecTy.getNumElements() &&
3769 "Invalid number of slice elements");
3779 class StridedSliceBroadcast final
3791 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3792 auto dstVecType = llvm::cast<VectorType>(op.getType());
3793 unsigned dstRank = dstVecType.getRank();
3794 unsigned rankDiff = dstRank - srcRank;
3798 bool lowerDimMatch =
true;
3799 for (
unsigned i = 0; i < srcRank; i++) {
3800 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3801 lowerDimMatch =
false;
3810 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3811 if (!lowerDimMatch && !isScalarSrc) {
3812 source = rewriter.
create<ExtractStridedSliceOp>(
3813 op->getLoc(), source,
3824 class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
3830 auto splat = op.getVector().getDefiningOp<SplatOp>();
3854 class ContiguousExtractStridedSliceToExtract final
3861 if (op.hasNonUnitStrides())
3863 Value source = op.getOperand();
3864 auto sourceType = cast<VectorType>(source.
getType());
3865 if (sourceType.isScalable() || sourceType.getRank() == 0)
3874 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3875 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
3882 if (numOffsets == 0)
3887 if (numOffsets == sourceType.getRank() &&
3888 static_cast<int>(sizes.size()) == sourceType.getRank())
3892 for (
int i = 0; i < numOffsets; ++i) {
3900 while (sizes[numOffsets] == 1 &&
3901 numOffsets <
static_cast<int>(sizes.size()) - 1) {
3906 auto extractOffsets =
ArrayRef(offsets).take_front(numOffsets);
3907 Value extract = rewriter.
create<vector::ExtractOp>(op->getLoc(), source,
3916 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3920 results.
add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3921 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3922 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
3932 VectorType vectorType,
Value source,
3933 ValueRange indices, AffineMapAttr permutationMapAttr,
3934 ArrayAttr inBoundsAttr) {
3935 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3936 Value padding = builder.
create<arith::ConstantOp>(
3938 build(builder, result, vectorType, source, indices, permutationMapAttr,
3939 padding,
Value(), inBoundsAttr);
3944 VectorType vectorType,
Value source,
3948 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3952 build(builder, result, vectorType, source, indices, permutationMapAttr,
3958 VectorType vectorType,
Value source,
3962 llvm::cast<ShapedType>(source.
getType()), vectorType);
3964 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3968 build(builder, result, vectorType, source, indices, permutationMapAttr,
3970 Value(), inBoundsAttr);
3976 VectorType vectorType,
Value source,
3979 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3980 Value padding = builder.
create<arith::ConstantOp>(
3982 build(builder, result, vectorType, source, indices, padding, inBounds);
3985 template <
typename EmitFun>
3987 EmitFun emitOpError) {
3989 for (
auto expr : permutationMap.
getResults()) {
3990 auto dim = dyn_cast<AffineDimExpr>(expr);
3991 auto zero = dyn_cast<AffineConstantExpr>(expr);
3993 if (zero.getValue() != 0) {
3995 "requires a projected permutation_map (at most one dim or the zero "
3996 "constant can appear in each result)");
4001 return emitOpError(
"requires a projected permutation_map (at most one "
4002 "dim or the zero constant can appear in each result)");
4004 if (seen[dim.getPosition()]) {
4006 "requires a permutation_map that is a permutation (found one dim "
4007 "used more than once)");
4009 seen[dim.getPosition()] =
true;
4014 static LogicalResult
4016 VectorType vectorType, VectorType maskType,
4017 VectorType inferredMaskType,
AffineMap permutationMap,
4018 ArrayAttr inBounds) {
4019 if (op->hasAttr(
"masked")) {
4020 return op->emitOpError(
"masked attribute has been removed. "
4021 "Use in_bounds instead.");
4024 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4025 return op->emitOpError(
4026 "requires source to be a memref or ranked tensor type");
4028 auto elementType = shapedType.getElementType();
4029 DataLayout dataLayout = DataLayout::closest(op);
4030 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4032 unsigned sourceVecSize =
4034 vectorElementType.getShape().back();
4035 unsigned resultVecSize =
4037 vectorType.getShape().back();
4038 if (resultVecSize % sourceVecSize != 0)
4039 return op->emitOpError(
4040 "requires the bitwidth of the minor 1-D vector to be an integral "
4041 "multiple of the bitwidth of the minor 1-D vector of the source");
4043 unsigned sourceVecEltRank = vectorElementType.getRank();
4044 unsigned resultVecRank = vectorType.getRank();
4045 if (sourceVecEltRank > resultVecRank)
4046 return op->emitOpError(
4047 "requires source vector element and vector result ranks to match.");
4048 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4051 return op->emitOpError(
"requires a permutation_map with result dims of "
4052 "the same rank as the vector type");
4055 return op->emitOpError(
"does not support masks with vector element type");
4058 unsigned minorSize =
4059 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4060 unsigned resultVecSize =
4063 return op->emitOpError(
4064 "requires the bitwidth of the minor 1-D vector to be an integral "
4065 "multiple of the bitwidth of the source element type");
4069 return op->emitOpError(
"requires a permutation_map with result dims of "
4070 "the same rank as the vector type");
4074 return op->emitOpError(
"requires permutation_map without symbols");
4076 if (permutationMap.
getNumInputs() != shapedType.getRank())
4077 return op->emitOpError(
"requires a permutation_map with input dims of the "
4078 "same rank as the source type");
4080 if (maskType && maskType != inferredMaskType)
4081 return op->emitOpError(
"inferred mask type (")
4082 << inferredMaskType <<
") and mask operand type (" << maskType
4085 if (permutationMap.
getNumResults() !=
static_cast<int64_t
>(inBounds.size()))
4086 return op->emitOpError(
"expects the in_bounds attr of same rank "
4087 "as permutation_map results: ")
4089 <<
" vs inBounds of size: " << inBounds.size();
4096 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4097 if (op.getPermutationMap().isMinorIdentity())
4098 elidedAttrs.push_back(op.getPermutationMapAttrName());
4100 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
4101 elidedAttrs.push_back(op.getInBoundsAttrName());
4106 p <<
" " << getSource() <<
"[" <<
getIndices() <<
"], " << getPadding();
4108 p <<
", " << getMask();
4117 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
4122 if (maskShape.empty())
4123 maskShape.push_back(1);
4145 if (hasMask.succeeded()) {
4152 if (types.size() != 2)
4153 return parser.
emitError(typesLoc,
"requires two types");
4155 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4156 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4157 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4158 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4160 return parser.
emitError(typesLoc,
"requires vector type");
4161 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.
name);
4168 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4170 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.
name);
4172 if (!inBoundsAttr) {
4182 if (hasMask.succeeded()) {
4183 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4185 maskInfo.
location,
"does not support masks with vector element type");
4188 "expected the same rank for the vector and the "
4189 "results of the permutation map");
4197 result.
addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4199 {1, static_cast<int32_t>(indexInfo.size()), 1,
4200 static_cast<int32_t>(hasMask.succeeded())}));
4206 ShapedType shapedType = getShapedType();
4208 VectorType maskType = getMaskType();
4209 auto paddingType = getPadding().getType();
4210 auto permutationMap = getPermutationMap();
4211 VectorType inferredMaskType =
4214 auto sourceElementType = shapedType.getElementType();
4216 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
4217 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4219 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4220 shapedType, vectorType, maskType,
4221 inferredMaskType, permutationMap, getInBounds())))
4224 if (
auto sourceVectorElementType =
4225 llvm::dyn_cast<VectorType>(sourceElementType)) {
4228 if (sourceVectorElementType != paddingType)
4230 "requires source element type and padding type to match.");
4234 if (!VectorType::isValidElementType(paddingType))
4235 return emitOpError(
"requires valid padding vector elemental type");
4238 if (paddingType != sourceElementType)
4240 "requires formal padding and source of the same elemental type");
4244 [&](Twine t) {
return emitOpError(t); });
4251 Type TransferReadOp::getExpectedMaskType() {
4255 template <
typename TransferOp>
4256 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4259 if (op.getShapedType().isDynamicDim(indicesIdx))
4261 Value index = op.getIndices()[indicesIdx];
4263 if (!cstOp.has_value())
4266 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4267 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4269 return cstOp.value() + vectorSize <= sourceSize;
4272 template <
typename TransferOp>
4276 if (op.getTransferRank() == 0)
4281 newInBounds.reserve(op.getTransferRank());
4286 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
4288 if (op.isDimInBounds(i)) {
4289 newInBounds.push_back(
true);
4294 bool inBounds =
false;
4295 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
4298 dimExpr.getPosition());
4299 nonBcastDims.push_back(i);
4302 newInBounds.push_back(inBounds);
4310 bool allNonBcastDimsInBounds = llvm::all_of(
4311 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
4312 if (allNonBcastDimsInBounds) {
4315 newInBounds[idx] =
true;
4327 template <
typename TransferOp>
4329 auto mask = op.getMask();
4336 op.getMaskMutable().clear();
4350 static Value foldRAW(TransferReadOp readOp) {
4351 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4353 auto defWrite = readOp.getSource().
getDefiningOp<vector::TransferWriteOp>();
4356 return defWrite.getVector();
4358 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4359 cast<VectorTransferOpInterface>(readOp.getOperation())))
4361 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4367 if (
Value vec = foldRAW(*
this))
4381 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4385 void TransferReadOp::getEffects(
4388 if (llvm::isa<MemRefType>(getShapedType()))
4394 if (hasPureTensorSemantics())
4422 struct TransferReadAfterWriteToBroadcast
4428 if (readOp.hasOutOfBoundsDim() ||
4429 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4431 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4436 if (readOp.getTransferChunkAccessed() !=
4437 defWrite.getTransferChunkAccessed())
4444 if (readOp.getIndices() != defWrite.getIndices() ||
4445 readOp.getMask() != defWrite.getMask())
4447 Value vec = defWrite.getVector();
4469 broadcastShape[pos.value()] = destShape[pos.index()];
4470 broadcastScalableFlags[pos.value()] =
4471 readOp.getVectorType().getScalableDims()[pos.index()];
4474 broadcastShape, defWrite.getVectorType().getElementType(),
4475 broadcastScalableFlags);
4476 vec = rewriter.
create<vector::BroadcastOp>(loc, broadcastedType, vec);
4487 results.
add<TransferReadAfterWriteToBroadcast>(context);
4497 AffineMapAttr permutationMapAttr,
4499 ArrayAttr inBoundsAttr) {
4500 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
4501 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4502 mask, inBoundsAttr);
4508 AffineMapAttr permutationMapAttr,
4509 ArrayAttr inBoundsAttr) {
4510 build(builder, result, vector, dest, indices, permutationMapAttr,
4511 Value(), inBoundsAttr);
4522 (inBounds && !inBounds.value().empty())
4525 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
4526 build(builder, result, vector, dest, indices, permutationMapAttr,
4527 Value(), inBoundsAttr);
4535 auto vectorType = llvm::cast<VectorType>(vector.
getType());
4537 llvm::cast<ShapedType>(dest.
getType()), vectorType);
4538 build(builder, result, vector, dest, indices, permutationMap, inBounds);
4554 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
4559 if (types.size() != 2)
4560 return parser.
emitError(typesLoc,
"requires two types");
4562 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4564 return parser.
emitError(typesLoc,
"requires vector type");
4565 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4566 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4567 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4568 auto permMapAttrName =
4569 TransferWriteOp::getPermutationMapAttrName(result.
name);
4576 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4578 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.
name);
4580 if (!inBoundsAttr) {
4589 if (hasMask.succeeded()) {
4590 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4592 maskInfo.
location,
"does not support masks with vector element type");
4595 "expected the same rank for the vector and the "
4596 "results of the permutation map");
4602 result.
addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4604 {1, 1, static_cast<int32_t>(indexInfo.size()),
4605 static_cast<int32_t>(hasMask.succeeded())}));
4606 return failure(llvm::isa<RankedTensorType>(shapedType) &&
4611 p <<
" " << getVector() <<
", " << getSource() <<
"[" <<
getIndices() <<
"]";
4613 p <<
", " << getMask();
4620 ShapedType shapedType = getShapedType();
4622 VectorType maskType = getMaskType();
4623 auto permutationMap = getPermutationMap();
4624 VectorType inferredMaskType =
4628 if (llvm::size(
getIndices()) != shapedType.getRank())
4629 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4633 if (hasBroadcastDim())
4634 return emitOpError(
"should not have broadcast dimensions");
4636 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4637 shapedType, vectorType, maskType,
4638 inferredMaskType, permutationMap, getInBounds())))
4642 [&](Twine t) {
return emitOpError(t); });
4649 Type TransferWriteOp::getExpectedMaskType() {
4670 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4674 if (write.getTransferRank() == 0)
4676 auto rankedTensorType =
4677 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4679 if (!rankedTensorType)
4682 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4686 if (read.getTransferRank() == 0)
4689 if (!read.getPermutationMap().isMinorIdentity() ||
4690 !write.getPermutationMap().isMinorIdentity())
4693 if (read.getTransferRank() != write.getTransferRank())
4696 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4699 if (read.getSource().getType() != rankedTensorType)
4702 if (read.getVectorType() != write.getVectorType())
4705 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4708 auto isNotConstantZero = [](
Value v) {
4710 return !cstOp.has_value() || cstOp.value() != 0;
4712 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4713 llvm::any_of(write.getIndices(), isNotConstantZero))
4716 results.push_back(read.getSource());
4720 static bool checkSameValueWAR(vector::TransferReadOp read,
4721 vector::TransferWriteOp write) {
4722 return read.getSource() == write.getSource() &&
4723 read.getIndices() == write.getIndices() &&
4724 read.getPermutationMap() == write.getPermutationMap() &&
4725 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4742 static LogicalResult foldWAR(TransferWriteOp write,
4744 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4746 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4750 if (!checkSameValueWAR(read, write))
4752 results.push_back(read.getSource());
4756 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4758 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
4760 if (succeeded(foldWAR(*
this, results)))
4769 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4773 void TransferWriteOp::getEffects(
4776 if (llvm::isa<MemRefType>(getShapedType()))
4782 if (hasPureTensorSemantics())
4817 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4819 vector::TransferWriteOp writeToModify = writeOp;
4822 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4826 writeToModify.getSourceMutable().assign(defWrite.getSource());
4831 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4832 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4836 if (!defWrite->hasOneUse())
4838 writeToModify = defWrite;
4839 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4868 struct SwapExtractSliceOfTransferWrite
4875 if (!insertOp.hasUnitStride())
4878 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4879 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4881 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4882 if (!transferOp || !transferOp->hasOneUse())
4887 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4889 "use-def chain is rank-reducing");
4893 if (!extractOp.hasZeroOffset()) {
4895 "ExtractSliceOp has non-zero offset");
4899 if (!llvm::all_of(transferOp.getIndices(), [](
Value value) {
4903 "TranferWriteOp has non-zero offset");
4907 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4909 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
4912 for (
auto [insertSize, extractSize] :
4913 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4916 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
4921 assert(transferOp.getVectorType().hasStaticShape() &&
4922 "expected vector to have a static shape");
4925 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4926 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
4928 insertOp,
"TransferWriteOp may not write the full tensor.");
4934 auto newExtractOp = rewriter.
create<tensor::ExtractSliceOp>(
4935 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4936 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4937 insertOp.getMixedStrides());
4938 auto newTransferWriteOp = rewriter.
create<TransferWriteOp>(
4939 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4940 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4943 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4953 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4960 static LogicalResult verifyLoadStoreMemRefLayout(
Operation *op,
4962 MemRefType memRefTy) {
4965 if (!vecTy.isScalable() &&
4966 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
4970 return op->
emitOpError(
"most minor memref dim must have unit stride");
4978 if (failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
4982 Type memElemTy = memRefTy.getElementType();
4983 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4984 if (memVecTy != resVecTy)
4985 return emitOpError(
"base memref and result vector types should match");
4986 memElemTy = memVecTy.getElementType();
4989 if (resVecTy.getElementType() != memElemTy)
4990 return emitOpError(
"base and result element types should match");
4991 if (llvm::size(
getIndices()) != memRefTy.getRank())
4992 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5010 if (failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
5014 Type memElemTy = memRefTy.getElementType();
5015 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5016 if (memVecTy != valueVecTy)
5018 "base memref and valueToStore vector types should match");
5019 memElemTy = memVecTy.getElementType();
5022 if (valueVecTy.getElementType() != memElemTy)
5023 return emitOpError(
"base and valueToStore element type should match");
5024 if (llvm::size(
getIndices()) != memRefTy.getRank())
5025 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5029 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5039 VectorType maskVType = getMaskVectorType();
5040 VectorType passVType = getPassThruVectorType();
5044 if (resVType.getElementType() != memType.getElementType())
5045 return emitOpError(
"base and result element type should match");
5046 if (llvm::size(
getIndices()) != memType.getRank())
5047 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5048 if (resVType.getShape() != maskVType.getShape())
5049 return emitOpError(
"expected result shape to match mask shape");
5050 if (resVType != passVType)
5051 return emitOpError(
"expected pass_thru of same type as result type");
5064 load, load.getType(), load.getBase(), load.getIndices());
5067 rewriter.
replaceOp(load, load.getPassThru());
5072 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
5079 results.
add<MaskedLoadFolder>(context);
5093 VectorType maskVType = getMaskVectorType();
5097 if (valueVType.getElementType() != memType.getElementType())
5098 return emitOpError(
"base and valueToStore element type should match");
5099 if (llvm::size(
getIndices()) != memType.getRank())
5100 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5101 if (valueVType.getShape() != maskVType.getShape())
5102 return emitOpError(
"expected valueToStore shape to match mask shape");
5115 store, store.getValueToStore(), store.getBase(), store.getIndices());
5123 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
5130 results.
add<MaskedStoreFolder>(context);
5133 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5143 VectorType indVType = getIndexVectorType();
5144 VectorType maskVType = getMaskVectorType();
5146 ShapedType baseType = getBaseType();
5148 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5149 return emitOpError(
"requires base to be a memref or ranked tensor type");
5151 if (resVType.getElementType() != baseType.getElementType())
5152 return emitOpError(
"base and result element type should match");
5153 if (llvm::size(
getIndices()) != baseType.getRank())
5154 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
5155 if (resVType.getShape() != indVType.getShape())
5156 return emitOpError(
"expected result dim to match indices dim");
5157 if (resVType.getShape() != maskVType.getShape())
5158 return emitOpError(
"expected result dim to match mask dim");
5159 if (resVType != getPassThruVectorType())
5160 return emitOpError(
"expected pass_thru of same type as result type");
5168 Type GatherOp::getExpectedMaskType() {
5169 auto vecType = this->getIndexVectorType();
5172 vecType.getScalableDims());
5175 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5189 rewriter.
replaceOp(gather, gather.getPassThru());
5194 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
5201 results.
add<GatherFolder>(context);
5209 VectorType indVType = getIndexVectorType();
5210 VectorType maskVType = getMaskVectorType();
5214 if (valueVType.getElementType() != memType.getElementType())
5215 return emitOpError(
"base and valueToStore element type should match");
5216 if (llvm::size(
getIndices()) != memType.getRank())
5217 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5218 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5219 return emitOpError(
"expected valueToStore dim to match indices dim");
5220 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5221 return emitOpError(
"expected valueToStore dim to match mask dim");
5240 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
5247 results.
add<ScatterFolder>(context);
5255 VectorType maskVType = getMaskVectorType();
5256 VectorType passVType = getPassThruVectorType();
5260 if (resVType.getElementType() != memType.getElementType())
5261 return emitOpError(
"base and result element type should match");
5262 if (llvm::size(
getIndices()) != memType.getRank())
5263 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5264 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5265 return emitOpError(
"expected result dim to match mask dim");
5266 if (resVType != passVType)
5267 return emitOpError(
"expected pass_thru of same type as result type");
5280 expand, expand.getType(), expand.getBase(), expand.getIndices());
5283 rewriter.
replaceOp(expand, expand.getPassThru());
5288 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
5295 results.
add<ExpandLoadFolder>(context);
5303 VectorType maskVType = getMaskVectorType();
5307 if (valueVType.getElementType() != memType.getElementType())
5308 return emitOpError(
"base and valueToStore element type should match");
5309 if (llvm::size(
getIndices()) != memType.getRank())
5310 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5311 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5312 return emitOpError(
"expected valueToStore dim to match mask dim");
5317 class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
5325 compress, compress.getValueToStore(), compress.getBase(),
5326 compress.getIndices());
5334 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
5341 results.
add<CompressStoreFolder>(context);
5350 setResultRanges(getResult(), argRanges.front());
5356 unsigned rankA = a.size();
5357 unsigned rankB = b.size();
5358 assert(rankA < rankB);
5360 auto isOne = [](int64_t v) {
return v == 1; };
5364 if (rankA == 0 && llvm::all_of(b, isOne))
5369 while (i < rankA &&
j < rankB) {
5370 int64_t dimA = a[i];
5372 while (dimB < dimA &&
j < rankB)
5380 if (i < rankA && llvm::all_of(a.slice(i), isOne))
5382 if (
j < rankB && llvm::all_of(b.slice(
j), isOne))
5386 return i == rankA &&
j == rankB;
5389 static LogicalResult verifyVectorShapeCast(
Operation *op,
5390 VectorType sourceVectorType,
5391 VectorType resultVectorType) {
5393 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5394 return op->
emitOpError(
"source/result vectors must have same element type");
5395 auto sourceShape = sourceVectorType.getShape();
5396 auto resultShape = resultVectorType.getShape();
5399 int64_t sourceDimProduct = std::accumulate(
5400 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5401 int64_t resultDimProduct = std::accumulate(
5402 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5403 if (sourceDimProduct != resultDimProduct)
5404 return op->
emitOpError(
"source/result number of elements must match");
5407 unsigned sourceRank = sourceVectorType.getRank();
5408 unsigned resultRank = resultVectorType.getRank();
5409 if (sourceRank < resultRank) {
5410 if (!isValidShapeCast(sourceShape, resultShape))
5412 }
else if (sourceRank > resultRank) {
5413 if (!isValidShapeCast(resultShape, sourceShape))
5418 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5419 int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5420 if (sourceNScalableDims != resultNScalableDims)
5421 return op->
emitOpError(
"different number of scalable dims at source (")
5422 << sourceNScalableDims <<
") and result (" << resultNScalableDims
5424 sourceVectorType.getNumDynamicDims();
5430 auto sourceVectorType =
5431 llvm::dyn_cast_or_null<VectorType>(getSource().
getType());
5432 auto resultVectorType =
5433 llvm::dyn_cast_or_null<VectorType>(getResult().
getType());
5436 if (sourceVectorType && resultVectorType)
5437 return verifyVectorShapeCast(*
this, sourceVectorType, resultVectorType);
5448 if (
auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5449 if (getResult().
getType() == otherOp.getSource().getType())
5450 return otherOp.getSource();
5453 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5454 VectorType resultType = llvm::cast<VectorType>(getResult().
getType());
5455 if (srcType.getRank() < resultType.getRank()) {
5456 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5458 }
else if (srcType.getRank() > resultType.getRank()) {
5459 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5465 setOperand(otherOp.getSource());
5470 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5471 if (bcastOp.getSourceType() ==
getType())
5472 return bcastOp.getSource();
5480 class ShapeCastConstantFolder final :
public OpRewritePattern<ShapeCastOp> {
5487 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5491 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5507 static VectorType trimTrailingOneDims(VectorType oldType) {
5514 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5515 newShape = newShape.drop_back(1);
5516 newScalableDims = newScalableDims.drop_back(1);
5521 if (newShape.empty()) {
5522 newShape = oldShape.take_back();
5523 newScalableDims = oldScalableDims.take_back();
5526 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5541 class ShapeCastCreateMaskFolderTrailingOneDim final
5548 Value shapeOpSrc = shapeOp->getOperand(0);
5549 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
5550 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
5551 if (!createMaskOp && !constantMaskOp)
5554 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5555 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5557 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5558 if (newVecType != shapeOpResTy)
5561 auto numDimsToDrop =
5562 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5569 auto maskOperands = createMaskOp.getOperands();
5570 auto numMaskOperands = maskOperands.size();
5573 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5575 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5576 if (!constant || (constant.value() != 1))
5580 maskOperands.drop_back(numDimsToDrop);
5587 if (constantMaskOp) {
5588 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5589 auto numMaskOperands = maskDimSizes.size();
5592 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5594 if (maskDimSizes[i] != 1)
5598 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5613 class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
5620 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5625 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5626 broadcastSourceShape = srcType.getShape();
5628 shapeCastOp.getResultVectorType().getShape();
5632 if (broadcastSourceShape ==
5633 shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5635 shapeCastOp, shapeCastOp.getResultVectorType(),
5636 broadcastOp.getSource());
5642 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5643 if (srcType.getNumElements() ==
5644 shapeCastOp.getResultVectorType().getNumElements()) {
5646 shapeCastOp, shapeCastOp.getResultVectorType(),
5647 broadcastOp.getSource());
5660 results.
add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5661 ShapeCastBroadcastFolder>(context);
5669 auto sourceVectorType = getSourceVectorType();
5670 auto resultVectorType = getResultVectorType();
5672 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5673 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5674 return emitOpError(
"dimension size mismatch at: ") << i;
5677 DataLayout dataLayout = DataLayout::closest(*
this);
5678 auto sourceElementBits =
5680 auto resultElementBits =
5683 if (sourceVectorType.getRank() == 0) {
5684 if (sourceElementBits != resultElementBits)
5685 return emitOpError(
"source/result bitwidth of the 0-D vector element "
5686 "types must be equal");
5687 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
5688 resultElementBits * resultVectorType.getShape().back()) {
5690 "source/result bitwidth of the minor 1-D vectors must be equal");
5702 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5703 if (getResult().
getType() == otherOp.getSource().getType())
5704 return otherOp.getSource();
5706 setOperand(otherOp.getSource());
5710 Attribute sourceConstant = adaptor.getSource();
5711 if (!sourceConstant)
5714 Type srcElemType = getSourceVectorType().getElementType();
5715 Type dstElemType = getResultVectorType().getElementType();
5717 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5718 if (floatPack.isSplat()) {
5719 auto splat = floatPack.getSplatValue<FloatAttr>();
5722 if (srcElemType.
isF16() && dstElemType.
isF32()) {
5723 uint32_t bits =
static_cast<uint32_t
>(
5724 splat.getValue().bitcastToAPInt().getZExtValue());
5726 bits = (bits << 16) | (bits & 0xffff);
5727 APInt intBits(32, bits);
5728 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5734 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5735 if (intPack.isSplat()) {
5736 auto splat = intPack.getSplatValue<IntegerAttr>();
5738 if (llvm::isa<IntegerType>(dstElemType)) {
5743 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5744 APInt intBits = splat.getValue().zext(dstBitWidth);
5747 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5748 intBits = (intBits << srcBitWidth) | intBits;
5763 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5766 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5775 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
5776 VectorType vectorType =
5780 memRefType.getMemorySpace()));
5785 if (!canonicalType.getLayout().isIdentity())
5786 return emitOpError(
"expects operand to be a memref with identity layout");
5787 if (!getResultMemRefType().getLayout().isIdentity())
5788 return emitOpError(
"expects result to be a memref with identity layout");
5789 if (getResultMemRefType().getMemorySpace() !=
5791 return emitOpError(
"expects result in same memory space");
5794 auto resultType = getResultMemRefType();
5798 "expects result and operand with same underlying scalar type: ")
5800 if (extractShape(sourceType) != extractShape(resultType))
5802 "expects concatenated result and operand shapes to be equal: ")
5813 VectorType vt = llvm::cast<VectorType>(vector.
getType());
5816 for (
unsigned i = 0; i < permutation.size(); ++i) {
5817 transposedShape[i] = vt.getShape()[permutation[i]];
5818 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5823 transposedScalableDims));
5828 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5831 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5833 return attr.reshape(getResultVectorType());
5841 for (int64_t i = 0, e = perm.size(); i < e; i++) {
5850 VectorType vectorType = getSourceVectorType();
5851 VectorType resultType = getResultVectorType();
5852 int64_t rank = resultType.getRank();
5853 if (vectorType.getRank() != rank)
5854 return emitOpError(
"vector result rank mismatch: ") << rank;
5857 int64_t size = perm.size();
5859 return emitOpError(
"transposition length mismatch: ") << size;
5862 if (ta.value() < 0 || ta.value() >= rank)
5863 return emitOpError(
"transposition index out of range: ") << ta.value();
5864 if (seen[ta.value()])
5865 return emitOpError(
"duplicate position index: ") << ta.value();
5866 seen[ta.value()] =
true;
5867 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5868 return emitOpError(
"dimension size mismatch at: ") << ta.value();
5873 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5874 return llvm::to_vector<4>(getResultVectorType().
getShape());
5880 class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
5890 for (
auto index : permutation2)
5891 result.push_back(permutation1[index]);
5896 vector::TransposeOp parentTransposeOp =
5897 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5898 if (!parentTransposeOp)
5902 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5905 transposeOp, transposeOp.getResult().getType(),
5906 parentTransposeOp.getVector(), permutation);
5912 struct FoldTransposedScalarBroadcast final
5918 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5922 auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5923 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5925 transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5940 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5945 transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5951 class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
5957 Value transposeSrc = transpOp.getVector();
5958 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
5959 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
5960 if (!createMaskOp && !constantMaskOp)
5968 auto maskOperands = createMaskOp.getOperands();
5973 transpOp, transpOp.getResultVectorType(), newOperands);
5978 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5982 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
5989 void vector::TransposeOp::getCanonicalizationPatterns(
5991 results.
add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5992 TransposeFolder, FoldTransposeSplat>(context);
6001 assert(kind == ConstantMaskKind::AllTrue ||
6002 kind == ConstantMaskKind::AllFalse);
6003 build(builder, result, type,
6004 kind == ConstantMaskKind::AllTrue
6010 auto resultType = llvm::cast<VectorType>(getResult().
getType());
6012 if (resultType.getRank() == 0) {
6013 if (getMaskDimSizes().size() != 1)
6014 return emitError(
"array attr must have length 1 for 0-D vectors");
6015 auto dim = getMaskDimSizes()[0];
6016 if (dim != 0 && dim != 1)
6017 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
6022 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
6024 "must specify array attr of size equal vector result rank");
6027 auto resultShape = resultType.getShape();
6028 auto resultScalableDims = resultType.getScalableDims();
6030 for (
const auto [index, maskDimSize] :
llvm::enumerate(maskDimSizes)) {
6031 if (maskDimSize < 0 || maskDimSize > resultShape[index])
6033 "array attr of size out of bounds of vector result dimension size");
6034 if (resultScalableDims[index] && maskDimSize != 0 &&
6035 maskDimSize != resultShape[index])
6037 "only supports 'none set' or 'all set' scalable dimensions");
6041 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6042 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
6043 if (anyZeros && !allZeros)
6044 return emitOpError(
"expected all mask dim sizes to be zeros, "
6045 "as a result of conjunction with zero mask dim");
6049 bool ConstantMaskOp::isAllOnesMask() {
6052 if (resultType.getRank() == 0) {
6053 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
6054 return getMaskDimSizes()[0] == 1;
6056 for (
const auto [resultSize, maskDimSize] :
6057 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6058 if (maskDimSize < resultSize)
6073 build(builder, result, type, operands);
6077 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
6079 if (vectorType.getRank() == 0) {
6080 if (getNumOperands() != 1)
6082 "must specify exactly one operand for 0-D create_mask");
6083 }
else if (getNumOperands() !=
6084 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
6086 "must specify an operand for each result vector dimension");
6122 VectorType maskType = createMaskOp.getVectorType();
6124 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6127 constexpr std::array<int64_t, 1> rankZeroShape{1};
6128 constexpr std::array<bool, 1> rankZeroScalableDims{
false};
6129 if (maskType.getRank() == 0) {
6130 maskTypeDimSizes = rankZeroShape;
6131 maskTypeDimScalableFlags = rankZeroScalableDims;
6137 for (
auto [i, dimSize] :
llvm::enumerate(createMaskOp.getOperands())) {
6142 if (maskTypeDimScalableFlags[i] && intSize >= 0)
6144 constantDims.push_back(*intSize);
6148 if (vscaleMultiplier < maskTypeDimSizes[i])
6150 constantDims.push_back(*vscaleMultiplier);
6157 for (
auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6158 value = std::clamp<int64_t>(value, 0, maskDimSize);
6161 if (llvm::is_contained(constantDims, 0))
6162 constantDims.assign(constantDims.size(), 0);
6175 results.
add<CreateMaskFolder>(context);
6186 assert(maskRegionBuilder &&
6187 "builder callback for 'maskRegion' must be present");
6193 maskRegionBuilder(builder, maskableOp);
6200 build(builder, result, resultTypes, mask,
Value(), maskableOp,
6208 build(builder, result, mask, maskableOp, maskRegionBuilder);
6229 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
6236 MaskOp::ensureTerminator(maskRegion, builder, result.
location);
6250 result.
types.append(resultTypes);
6256 if (parsePassthru.succeeded())
6264 p <<
" " << getMask();
6266 p <<
", " << getPassthru();
6270 Block *singleBlock = &getMaskRegion().getBlocks().
front();
6277 p <<
" : " << getMask().getType();
6278 if (getNumResults() > 0)
6279 p <<
" -> " << getResultTypes();
6284 MaskOp>::ensureTerminator(region, builder, loc);
6296 assert(isa<vector::YieldOp>(oldYieldOp) &&
"Expected vector::YieldOp");
6299 if (maskedOp == oldYieldOp)
6302 opBuilder.setInsertionPoint(oldYieldOp);
6303 opBuilder.create<vector::YieldOp>(loc, maskedOp->
getResults());
6305 oldYieldOp->
erase();
6310 Block &block = getMaskRegion().getBlocks().
front();
6312 return emitOpError(
"expects a terminator within the mask region");
6315 if (numMaskRegionOps > 2)
6316 return emitOpError(
"expects only one operation to mask");
6319 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
6321 return emitOpError(
"expects a terminator within the mask region");
6323 if (terminator->getNumOperands() != getNumResults())
6325 "expects number of results to match mask region yielded values");
6328 if (numMaskRegionOps == 1)
6331 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
6333 return emitOpError(
"expects a MaskableOpInterface within the mask region");
6337 return emitOpError(
"expects number of results to match maskable operation "
6338 "number of results");
6340 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
6342 "expects result type to match maskable operation result type");
6345 [](
Type t) { return llvm::isa<VectorType>(t); }) > 1)
6346 return emitOpError(
"multiple vector results not supported");
6349 Type expectedMaskType = maskableOp.getExpectedMaskType();
6350 if (getMask().
getType() != expectedMaskType)
6351 return emitOpError(
"expects a ")
6352 << expectedMaskType <<
" mask for the maskable operation";
6355 Value passthru = getPassthru();
6357 if (!maskableOp.supportsPassthru())
6359 "doesn't expect a passthru argument for this maskable operation");
6362 return emitOpError(
"expects result when passthru argument is provided");
6365 return emitOpError(
"expects passthru type to match result type");
6372 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6382 Operation *maskableOp = getMaskableOp();
6386 llvm::append_range(results, maskableOp->
getResults());
6398 auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6399 if (maskingOp.getMaskableOp())
6402 if (!maskOp.isEmpty())
6405 Block *block = maskOp.getMaskBlock();
6406 auto terminator = cast<vector::YieldOp>(block->
front());
6407 if (terminator.getNumOperands() == 0)
6410 rewriter.
replaceOp(maskOp, terminator.getOperands());
6418 results.
add<ElideEmptyMaskOp>(context);
6425 Block *block = getMaskBlock();
6429 return &block->
front();
6433 bool MaskOp::hasPassthru() {
return getPassthru() !=
Value(); }
6440 VectorType srcType = getSourceType();
6441 VectorType initialType = getInitialValueType();
6443 int64_t srcRank = srcType.getRank();
6444 int64_t reductionDim = getReductionDim();
6445 if (reductionDim >= srcRank)
6446 return emitOpError(
"reduction dimension ")
6447 << reductionDim <<
" has to be less than " << srcRank;
6450 int64_t initialValueRank = initialType.getRank();
6451 if (initialValueRank != srcRank - 1)
6452 return emitOpError(
"initial value rank ")
6453 << initialValueRank <<
" has to be equal to " << srcRank - 1;
6459 for (
int i = 0; i < srcRank; i++) {
6460 if (i != reductionDim)
6461 expectedShape.push_back(srcShape[i]);
6463 if (!llvm::equal(initialValueShapes, expectedShape)) {
6464 return emitOpError(
"incompatible input/initial value shapes");
6468 Type eltType = getDestType().getElementType();
6470 return emitOpError(
"unsupported reduction type ")
6471 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
6480 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6481 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6482 StridedSliceConstantMaskFolder, TransposeFolder>(
6491 auto constOperand = adaptor.getInput();
6492 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6501 setResultRanges(getResult(), argRanges.front());
6506 arith::FastMathFlagsAttr fastmath,
6513 case CombiningKind::ADD:
6516 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6517 result = b.
createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6519 llvm_unreachable(
"invalid value types for ADD reduction");
6521 case CombiningKind::AND:
6525 case CombiningKind::MAXNUMF:
6526 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6527 "expected float values");
6528 result = b.
createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6530 case CombiningKind::MAXIMUMF:
6531 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6532 "expected float values");
6533 result = b.
createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6535 case CombiningKind::MINNUMF:
6536 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6537 "expected float values");
6538 result = b.
createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6540 case CombiningKind::MINIMUMF:
6541 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6542 "expected float values");
6543 result = b.
createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6545 case CombiningKind::MAXSI:
6549 case CombiningKind::MINSI:
6553 case CombiningKind::MAXUI:
6561 case CombiningKind::MUL:
6564 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6565 result = b.
createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6567 llvm_unreachable(
"invalid value types for MUL reduction");
6569 case CombiningKind::OR:
6573 case CombiningKind::XOR:
6579 assert(result &&
"unknown CombiningKind");
6591 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
6611 return builder.
create<MaskOp>(maskableOp->getLoc(),
6612 maskableOp->getResultTypes(), mask, maskableOp,
6629 mask, newValue, passthru);
6636 #define GET_ATTRDEF_CLASSES
6637 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6639 #define GET_OP_CLASSES
6640 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
MaskFormat
Helper enum to classify a mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType & getOperations()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic dependences between references when they are to be deleted.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful possible benefit).
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for subscripts in the vector dialect.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the accessed tensor/memref to be the same.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operations to access the same tensor/memref.
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector.broadcast op.
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.