#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
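// Helper that classifies a vector mask value as all-true, all-false, or
// unknown by inspecting its defining op: a dense i1 arith.constant, a
// vector.constant_mask, or a vector.create_mask with constant operands.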
if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
  // Count up for set bits and down for cleared bits; a mix of both means
  // the mask format is unknown.
  int64_t val = 0;
  for (bool b : denseElts.getValues<bool>())
    if (b && val >= 0)
      val++;
    else if (!b && val <= 0)
      val--;
// Inspect a constant mask: a mask index smaller than the corresponding
// dimension size means this dimension is not all-true.
auto shape = m.getType().getShape();
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
  if (maskIdx < dimSize)
    allTrue = false;
// Inspect a create_mask: only operands defined by constants can be
// reasoned about.
auto maskOperands = m.getOperands();
for (Value operand : maskOperands) {
  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
    int64_t dimSize =
        llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
builder.create<vector::YieldOp>(loc);
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}
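// Illustrative example: for a transfer between memref<10x20x30xf32> and
// vector<4x5xf32>, the minor identity map computed above is
// (d0, d1, d2) -> (d1, d2): the vector maps onto the innermost (minor)
// dimensions of the shaped type.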
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // A splat write is consistent with a masked read only if the write is
  // unmasked or uses the same mask as the read.
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for a splat constant as the written vector.
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat())
    return false;
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
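// Note: the two checks above support store-to-load forwarding: a
// transfer_read may reuse the vector of an earlier transfer_write (RAW),
// and a transfer_write may make an earlier write to the same location dead
// (WAW), but only when indices, vector types, permutation maps, and masks
// all line up.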
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different we know we are accessing disjoint slices.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For trailing dimensions we slice a part of the shaped operand; make
      // sure the accessed intervals do not overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}
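// Example of the constant-distance reasoning above: two 1-D transfers of
// vector<4xf32> with constant indices 0 and 8 are disjoint because
// |0 - 8| >= 4, while constant indices 0 and 2 may overlap.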
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getSource() != transferB.getSource())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
// Advance `position` lexicographically within `shape`, starting from
// `offsets`; fail once the iteration space is exhausted.
for (auto [posInDim, dimSize, offsetInDim] :
     llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
  ++posInDim;
  if (posInDim < dimSize + offsetInDim)
    return success();
  // Carry the overflow to the next loop iteration.
  posInDim = offsetInDim;
}
return failure();
// Convert index values to static int64_t positions; every value must be a
// constant.
llvm::transform(values, std::back_inserter(ints), [](Value value) {
  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
  assert(constOp && "Unexpected non-constant index");
  return constOp.value();
});
llvm::transform(
    foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
      assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
      return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
    });
llvm::transform(foldResults, std::back_inserter(values),
                [&](OpFoldResult foldResult) -> Value {
                  if (auto attr = foldResult.dyn_cast<Attribute>())
                    return builder.create<arith::ConstantIndexOp>(
                        loc, cast<IntegerAttr>(attr).getInt());
                  return foldResult.get<Value>();
                });
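// Note: the helpers above convert between the three index representations
// used throughout this file: static int64_t positions, constant-foldable
// OpFoldResults, and SSA index values (materializing arith.constant ops on
// demand).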
// Match a `vector.vscale * constant` product in either operand order.
auto lhs = mul.getLhs();
auto rhs = mul.getRhs();
if (lhs.getDefiningOp<vector::VectorScaleOp>())
  return getConstantIntValue(rhs);
if (rhs.getDefiningOp<vector::VectorScaleOp>())
  return getConstantIntValue(lhs);
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
}
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}
std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}
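// Example of the shape check above: reducing dims [1] of vector<4x8xf32>
// must yield vector<4xf32>; reducing all dims yields the element type (f32).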
/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
/// Only unit dimensions that are being reduced are folded. If a dimension is
/// unit but not reduced, it is kept, so the output type is unchanged. If not
/// all reduced dimensions are unit, this pattern does nothing.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      }
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // All dimensions are reduced and all reduced dimensions have size 1,
      // so a simple extraction does the job.
      SmallVector<int64_t> zeroIdx(shape.size(), 0);
      if (mask)
        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
                                                zeroIdx);
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}
LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  // ...
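// Note: this mapping lets clients lower an arith.atomic_rmw-style combining
// kind directly to a vector.reduction; e.g. AtomicRMWKind::addf on
// vector<8xf32> becomes:
//   %0 = vector.reduction <add>, %v : vector<8xf32> into f32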
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    Value result;
    if (vectorType.getRank() == 0) {
      if (mask)
        mask = rewriter.create<ExtractElementOp>(loc, mask);
      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
    } else {
      if (mask)
        mask = rewriter.create<ExtractOp>(loc, mask, 0);
      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
    }

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  DictionaryAttr dictAttr;
  // ...
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  // ...
  // Convert the array of strings into an array of IteratorType enums. This
  // is needed because tests still use the old format where 'iterator_types'
  // is an array of strings.
  auto iteratorTypes = llvm::cast<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  SmallVector<Attribute> iteratorTypeAttrs;
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  // ...
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(result.getContext(),
                                             ContractionOp::getDefaultKind()));
  // ...
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  // ...
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
void ContractionOp::print(OpAsmPrinter &p) {
  // Print the trait attributes in a short-hand dictionary, everything else in
  // the regular attribute dictionary.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into their string representation; tests
      // still use the old format with 'iterator_types' as string arrays.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);
  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }
  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");
    // Infer the expected result vector type from the lhs/rhs maps and shapes:
    // every iteration dimension must receive exactly one extent.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");
    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");
  // Verify that each indexing map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated vector
  // operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // A scalar operand (e.g., a scalar accumulator) has rank 0 and thus a
    // map with zero results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }
  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();
  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // iteration dimension from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}
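// For the matmul-shaped example earlier (i = 4, j = 16, k = 8), the expected
// mask covers the full iteration space: vector<4x16x8xi1>, with scalability
// flags propagated from the corresponding lhs/rhs dimensions.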
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}
std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
/// Canonicalize `contract(a, b, 0) + c` into `contract(a, b, c)` by cloning
/// the contraction with `c` as the accumulator when the original accumulator
/// is a zero constant.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
void vector::ExtractElementOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source) {
  result.addOperands({source});
  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
}
LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here now.
  if (!adaptor.getPosition())
    return {};

  // Fold extractelement(splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extractelement(broadcast(X)) -> X.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
      return broadcast.getSource();

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!pos || !src)
    return {};

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}
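// For instance, extracting position 2 from a dense constant vector<4xi32>
// holding [0, 1, 2, 3] folds to the attribute 2; extracting any position
// from a splat or scalar broadcast folds to the splatted value.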
void vector::ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                          SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

// ...

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}
bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting a 1-element vector instead of a scalar.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}
LogicalResult vector::ExtractOp::verify() {
  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (pos.is<Attribute>()) {
      int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension";
      }
    }
  }
  return success();
}
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
/// Fold chains of ExtractOp in place by concatenating their positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}
/// Walks chains of vector.insert and vector.transpose ops feeding a
/// vector.extract, trying to fold the extract in place.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);
  /// Return true if the vector at position `a` is contained within the
  /// vector at position `b`, i.e. `a` is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }
  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`; sentinel (negative) entries always intersect in that
  /// dimension.
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }
  /// Update `nextInsertOp` / `nextTransposeOp` from the defining op of `v`.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  }

  // Case 1: transpose encountered; compose the permutation and iterate.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches the extract position exactly.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  // Case 3: the insert position is a prefix of the extract position.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold the extract in place by rewriting its operand and position.
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
  llvm::append_range(extractPosition, sentinels);
}
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
  return success();
}
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();
  // ...
  // Case 2.a. early-exit fold.
  res = nextInsertOp.getSource();
  // Case 2.b. if an internal transposition is present, canFold is false.
  return success(canFold());
}
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();
  // ...
  res = nextInsertOp.getSource();
  return success();
}
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  // Bail if there is nothing to fold or an internal transposition prevents it.
  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
    return Value();

  // Fold by updating the op in place and returning its result.
  extractOp.setStaticPosition(
      ArrayRef(extractPosition).take_front(extractedRank));
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();
}
Value ExtractFromInsertTransposeChainState::fold() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1: on a transpose, compose the map and iterate. Insert and
    // transpose do not change rank, so we can always keep extracting.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: the inserted position is a prefix of the extract position.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // ...
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
auto hasZeroDimVectorType = [](Type type) -> bool {
  auto vecType = dyn_cast<VectorType>(type);
  return vecType && vecType.getRank() == 0;
};
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();

  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                       : 0;
  };

  // If splat or broadcast from a scalar, just return the source scalar.
  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
    return source;

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank >= broadcastSrcRank)
    return Value();
  // Check that the dimensions of the result have not been broadcasted.
  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();

  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();

  // Detect all the positions that come from "dim-1" broadcasting; set the
  // matching extract position to 0 when extracting from the source operand.
  llvm::SetVector<int64_t> broadcastedUnitDims =
      broadcastOp.computeBroadcastedUnitDims();
  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))
      extractPos[i] = 0;
  // `rankDiff` leading dimensions correspond to new broadcasted dims; drop
  // the matching extract positions when extracting from the source operand.
  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  extractOp.setOperand(0, source);
  extractOp.setStaticPosition(extractPos);
  return extractOp.getResult();
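// Example: if %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>,
// then vector.extract %b[1] : vector<8xf32> from vector<4x8xf32> folds to
// %src, because the extracted type matches the broadcast source type.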
  // Dynamic positions are not folded: the resulting code would be more
  // complex than the input.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // 0-D shuffles are not supported.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Find the shuffled vector to extract from based on the shuffle index.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from the lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shape_cast source.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Compute the strides of the extract position in the original vector and
  // linearize the position.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Compute the strides of the shape_cast source and delinearize the
  // position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extract_strided_slice op.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
  if (!insertOp)
    return Value();

  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the inserted chunk.
    if (!disjoint) {
      // If any inner dimension is only partially inserted, bail.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getVectorMutable().assign(insertOp.getSource());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // If the extracted chunk is disjoint from the inserted chunk, keep
    // looking in the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  // Look for extract(from_elements).
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute the flattened/linearized index and fold to the operand.
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
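// Linearization example: for a vector<2x3xf32> built by vector.from_elements,
// position [1, 2] maps to flat index 1 * 3 + 2 = 5 in the element list.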
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
  // Note: do not fold "vector.extract %v[] : f32 from vector<f32>" (type
  // mismatch).
  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
    return getVector();
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  // ...
    Operation *defOp = extractOp.getVector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    if (extractOp.getType() == source.getType())
      return failure();
    auto getRank = [](Type type) {
      return llvm::isa<VectorType>(type)
                 ? llvm::cast<VectorType>(type).getRank()
                 : 0;
    };
    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    // Only handle the case where the source rank is less than or equal to
    // the extract result rank; the other cases are handled by the folders.
    if (extractResultRank < broadcastSrcRank)
      return failure();

    // Special case: the broadcast source is a 0-D vector.
    if (extractResultRank == 0) {
      assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
      return success();
    }
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        extractOp, extractOp.getType(), source);
    return success();
// Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a splat vector ConstantOp.
    Value sourceVector = extractOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();
    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
    if (!splat)
      return failure();
    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
    return success();
  }
};
// Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
class ExtractOpNonSplatConstantFolder final
    : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (extractOp.hasDynamicPosition())
      return failure();

    // Return if the operand is not defined by a compatible vector ConstantOp.
    Value sourceVector = extractOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
    if (vecTy.isScalable())
      return failure();

    // The splat case is handled by `ExtractOpSplatConstantFolder`.
    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
    if (!dense || dense.isSplat())
      return failure();

    // Calculate the linearized position of the continuous chunk of elements
    // to extract.
    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
    copy(extractOp.getStaticPosition(), completePositions.begin());
    int64_t elemBeginPosition =
        linearize(completePositions, computeStrides(vecTy.getShape()));
    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;

    TypedAttr newAttr;
    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
      SmallVector<Attribute> elementValues(
          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
    } else {
      newAttr = *denseValuesBegin;
    }

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
    return success();
  }
};
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim are unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the range of the create_mask, the
        // extracted mask is all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dim is not all-true, and since the index is dynamic we don't
        // know whether the extraction lands in the true or false region.
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
// Fold an extract from a shape_cast into a smaller shape_cast when the
// element counts match.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();

  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();

  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
  return success();
}
/// Try to fold the extraction of a subvector from a vector defined by
/// vector.from_elements, replacing it with a smaller from_elements op.
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of the first extracted element and linearize it.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  // Replace the op with a smaller from_elements op.
  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}
std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
    // Rewrite a from_elements op with all-identical operands into a splat.
    if (!llvm::all_equal(fromElementsOp.getElements()))
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp,
                                         fromElementsOp.getType(),
                                         fromElementsOp.getElements().front());
    return success();
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  llvm::SetVector<int64_t> res;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected dim-1 broadcasting");
      res.insert(dstDim);
    }
    ++dstDim;
  }
  return res;
auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness: the dst shape minus the broadcasted dims must match the
  // source shape.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 2. If scalar -> dstShape broadcast, just do it.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 3. vector.broadcast only creates leading dims, so a transpose may
  // be needed to move the broadcasted dims into their final positions.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // 3.a. Broadcasted dims come first in broadcastShape and get permuted
      // back into position `i`.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 3.b. Source dims keep their relative order and get permuted into
      // position `i`.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  // 3.c. Append the source shape.
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no remaining dim-1 broadcasts.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected dim-1 broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // Step 4. Transpose if any dim must be permuted.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  return res;
}
BroadcastableToResult mlir::vector::isBroadcastableTo(
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcasting a scalar to a vector of the same element type succeeds.
  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // The source must have an exact match or a singleton value for all trailing
  // dimensions (all leading dimensions are simply duplicated).
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have mismatching dims (in the sense of vector.broadcast semantics)
    // been encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine; everything else is rejected when mixing
        // fixed-width and scalable dims.
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}
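// Broadcast compatibility examples: f32 -> vector<4x8xf32> and
// vector<1x8xf32> -> vector<4x8xf32> succeed; vector<8xf32> -> vector<4xf32>
// is a DimensionMismatch (8 vs 4); vector<4x8xf32> -> vector<8xf32> fails
// with SourceRankHigher.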
LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, adaptor.getSource());
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  return {};
}
    // Fold broadcast(broadcast(x)) into a single broadcast(x).
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
                                             broadcastOp.getResultVectorType(),
                                             srcBroadcast.getSource());
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder>(context);
}
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (maskPos < 0 || maskPos >= indexSize)
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // Construct the resulting type: the leading dimension matches the mask
  // length, all trailing dimensions match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  VectorType v1Type = getV1VectorType();
  // A 0-D shuffle returns a 1-D vector, so this cannot be a simple folding;
  // it must be canonicalized into a vector.broadcast instead.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
  if (!v1Type.isScalable() &&
      isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
    return getV1();
  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
      isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
                       getV2VectorType().getDimSize(0)))
    return getV2();

  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
  if (!lhs || !rhs)
    return {};

  auto lhsType =
      llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (lhsType.getRank() != 1)
    return {};
  int64_t lhsSize = lhsType.getDimSize(0);

  SmallVector<Attribute> results;
  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
  for (int64_t i : this->getMask()) {
    if (i >= lhsSize) {
      results.push_back(rhsElements[i - lhsSize]);
    } else {
      results.push_back(lhsElements[i]);
    }
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
    // Canonicalize a 0-D shuffle (which returns a 1-D vector) into a
    // broadcast of the selected operand.
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
      return failure();
    if (mask.size() != 1)
      return failure();
    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();

    if (!v1Splat || !v2Splat)
      return failure();

    if (v1Splat.getInput() != v2Splat.getInput())
      return failure();
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent a scalable interleave");

    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp types don't match an interleave");
    }

    ArrayRef<int64_t> shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
        return rewriter.notifyMatchFailure(op,
                                           "ShuffleOp mask not interleaving");
    }

    rewriter.replaceOpWithNewOp<InterleaveOp>(op, resultType, op.getV1(),
                                              op.getV2());
    return success();
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);
}

void vector::InsertElementOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}
void vector::InsertElementOp::build(OpBuilder &builder, OperationState &result,
                                    Value source, Value dest) {
  build(builder, result, source, dest, {});
}
LogicalResult vector::InsertElementOp::verify() {
  auto dstVectorType = getDestVectorType();
  if (dstVectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (dstVectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}
OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here.
  if (!adaptor.getPosition())
    return {};

  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!src || !dst || !pos)
    return {};

  auto dstElements = dst.getValues<Attribute>();

  SmallVector<Attribute> results(dstElements);

  uint64_t posIdx = pos.getInt();
  if (posIdx >= results.size())
    return {};
  results[posIdx] = src;

  return DenseElementsAttr::get(getDestVectorType(), results);
}
void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}

// ...

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<int64_t> position) {
  SmallVector<OpFoldResult> posVals;
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult InsertOp::verify() {
  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = pos.dyn_cast<Attribute>()) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding "
                  "dest vector dimension";
      }
    }
  }
  return success();
}
    // Rewrite an insert covering the entire destination into a broadcast.
    auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getSource());
    return success();
    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();

    if (!srcSplat || !dstSplat)
      return failure();

    if (srcSplat.getInput() != dstSplat.getInput())
      return failure();
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (op.hasDynamicPosition())
      return failure();
    // Return if the destination is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();
    auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
    if (!denseDest)
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();
    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    Value sourceValue = op.getSource();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // Calculate the linearized position of the continuous chunk of elements
    // to insert.
    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
    copy(op.getStaticPosition(), completePositions.begin());
    int64_t insertBeginPosition =
        linearize(completePositions, computeStrides(destTy.getShape()));

    SmallVector<Attribute> insertedValues;
    Type destEltType = destTy.getElementType();
    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
      for (auto value : denseSource.getValues<Attribute>())
        insertedValues.push_back(convertIntegerAttr(value, destEltType));
    } else {
      insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
    }

    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
    copy(insertedValues, allValues.begin() + insertBeginPosition);
    auto newAttr = DenseElementsAttr::get(destTy, allValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
  /// Convert an IntegerAttr to the expected element type when they differ
  /// (this can happen with attributes produced via llvm.mlir.constant).
  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
    return attr;
  }
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertOpConstantFolder>(context);
}
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
  // Fold the (almost) noop case where the full source is re-inserted.
  if (getNumIndices() == 0 && getSourceType() == getType())
    return getSource();
  return {};
}
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}
// Returns true if all integers in `arrayAttr` are in the half-open [min, max)
// interval, or the closed [min, max] interval when `halfOpen` is false.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
// Returns true if each integer in `arrayAttr` is confined to the
// corresponding dimension in `shape`.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
// Returns true if, for each index, the sum of the corresponding entries of
// `arrayAttr1` and `arrayAttr2` is confined to the matching dimension in
// `shape`.
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto srcSplatOp =
        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
    auto destSplatOp =
        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();

    if (!srcSplatOp || !destSplatOp)
      return failure();

    if (srcSplatOp.getInput() != destSplatOp.getInput())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getSource()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check that the slices have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
/// Folds an insert_strided_slice of a constant slice into a constant
/// destination vector into a single new constant.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // ... (match a constant destination `destVector` / `vectorDestCst` and a
    // constant source `sourceValue` / `sourceCst`) ...
    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();
    // Insertion of a constant into a large constant would bloat the IR.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    // Walk every position of the slice and overwrite the corresponding
    // linearized element of the destination.
    // ... (position enumeration elided; `currDestPosition` is advanced over
    // the slice, skipping the leading `rankDifference` destination dims) ...
    int64_t linearizedPosition = linearize(currDestPosition, destStrides);
    assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
    assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
           "Invalid slice element");
    newValues[linearizedPosition] = *sliceValuesIt;
    // ... (build and substitute the new constant) ...
  }
};
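// Illustrative sketch (not part of VectorOps.cpp): how a multi-dimensional
// position is linearized against row-major strides, which is what the
// constant folder above relies on via `linearize(currDestPosition,
// destStrides)`. The helper names below are hypothetical.
//
// #include <cstdint>
// #include <vector>
//
// // Row-major strides for a shape, e.g. {4, 3, 2} -> {6, 2, 1}.
// static std::vector<int64_t>
// computeStridesSketch(const std::vector<int64_t> &shape) {
//   std::vector<int64_t> strides(shape.size(), 1);
//   for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
//     strides[i] = strides[i + 1] * shape[i + 1];
//   return strides;
// }
//
// // Dot product of position and strides gives the flat element index.
// static int64_t linearizeSketch(const std::vector<int64_t> &pos,
//                                const std::vector<int64_t> &strides) {
//   int64_t linear = 0;
//   for (size_t i = 0; i < pos.size(); ++i)
//     linear += pos[i] * strides[i];
//   return linear;
// }
//
// // Example: position {1, 2, 0} in a 4x3x2 vector -> 1*6 + 2*2 + 0*1 = 10.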
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getSource();
  return {};
}
void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (!getAcc().empty())
    p << ", " << getAcc();
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}

ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  // ... (operand list and `lhsType, rhsType` parsing elided) ...
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand.
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name)))
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what is currently supported for scalable
      // vectors; it could be relaxed if a use case arises.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}
/// The mask has the same shape (and scalability) as the result, with i1
/// elements.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

// Inference of the result type of extract_strided_slice:
//   1. Take 'sizes' for the leading dims that are sliced.
//   2. Take the source vector sizes for the remaining dims.
// Scalable flags are inherited from the source vector type.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}
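// Illustrative sketch (not part of VectorOps.cpp): the shape rule used by
// inferStridedSliceOpResultType, restated on plain integer shapes. The
// function name is hypothetical.
//
// #include <cstdint>
// #include <vector>
//
// static std::vector<int64_t>
// inferSliceShapeSketch(const std::vector<int64_t> &sourceShape,
//                       const std::vector<int64_t> &sizes) {
//   std::vector<int64_t> shape;
//   shape.reserve(sourceShape.size());
//   size_t idx = 0;
//   for (; idx < sizes.size(); ++idx)       // prefix: sliced sizes.
//     shape.push_back(sizes[idx]);
//   for (; idx < sourceShape.size(); ++idx) // suffix: untouched dims.
//     shape.push_back(sourceShape[idx]);
//   return shape;
// }
//
// // Example: source vector<4x8x16>, sizes = {2, 4} -> result vector<2x4x16>.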
// ExtractStridedSliceOp::build (fragment): infer the result type from the
// source vector type and attach the offsets/sizes/strides attributes.
  result.addTypes(inferStridedSliceOpResultType(
      llvm::cast<VectorType>(source.getType()), offsetsAttr, sizesAttr,
      stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}

LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  // ... (range checks over `shape`, `offName`, `sizesName` and `stridesName`
  // elided: offsets/sizes confined to the shape, strides all 1, and
  // offsets + sizes within bounds) ...

  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }
  return success();
}
// Replaces an extract_strided_slice that reads from an insert_strided_slice
// chain with a read of the inserted source, when the extracted chunk is
// fully contained in (or disjoint from) each inserted chunk.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract an integer out of an ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the rank of the extract exceeds that of the insert, we are likely
    // extracting a partial chunk of the inserted vector.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool patialoverlap = false; // (sic) upstream identifier spelling
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // The extract starts inside the inserted interval.
      if (start <= offset && offset < end) {
        // Overlapping but not fully contained: partial overlap.
        if (offset + size > end)
          patialoverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk is a subset of the inserted chunk: fold.
    if (!disjoint && !patialoverlap) {
      op.setOperand(insertOp.getSource());
      // ... (rewrite the offsets attribute with `offsetDiffs`) ...
      return success();
    }
    // A partial overlap cannot be folded.
    if (!disjoint)
      return failure();
    // Disjoint chunks: keep looking through the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return failure();
}
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getVector();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();
  return {};
}
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // ... (gather `maskDimSizes`, `sliceOffsets` and `sliceSizes`) ...
    // Compute the sliced mask region per dimension.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any sliced dimension is zero, the whole mask is all-false.
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
    // Replace with a ConstantMaskOp of the sliced mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
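// Illustrative sketch (not part of VectorOps.cpp): the interval arithmetic in
// StridedSliceConstantMaskFolder. A constant mask sets the first
// `maskDimSize` lanes of a dimension; slicing [offset, offset + size) keeps
// only the part of that prefix that falls inside the slice window. The
// function name is hypothetical.
//
// #include <algorithm>
// #include <cstdint>
//
// static int64_t slicedMaskDimSizeSketch(int64_t maskDimSize,
//                                        int64_t sliceOffset,
//                                        int64_t sliceSize) {
//   return std::max<int64_t>(
//       0, std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
// }
//
// // Examples: mask prefix 5 of a dim, slice [2, 6) -> 3 set lanes;
// //           mask prefix 5, slice [6, 8) -> 0 set lanes (disjoint).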
// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
class StridedSliceSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a splat constant.
    Value sourceVector = extractStridedSliceOp.getVector();
    // ... (match `sourceVector` against a constant `vectorCst`) ...
    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
    if (!splat)
      return failure();
    // ... (replace with a splat constant of the result type) ...
  }
};

// Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
// ConstantOp.
class StridedSliceNonSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a dense, non-splat constant.
    Value sourceVector = extractStridedSliceOp.getVector();
    // ... (match `sourceVector` against a constant `vectorCst`) ...
    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
    if (!dense || dense.isSplat())
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();

    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
    VectorType sliceVecTy = extractStridedSliceOp.getType();
    int64_t sliceRank = sliceVecTy.getRank();
    // ... (compute `sourceStrides` and the starting `currSlicePosition` from
    // the offsets, padded to the source rank) ...

    // Enumerate the slice positions and copy the corresponding source
    // elements.
    auto denseValuesBegin = dense.value_begin<Attribute>();
    SmallVector<Attribute> sliceValues;
    sliceValues.reserve(sliceVecTy.getNumElements());
    do {
      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
      assert(linearizedPosition < sourceVecTy.getNumElements() &&
             "Invalid index");
      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
    } while (
        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));

    assert(static_cast<int64_t>(sliceValues.size()) ==
               sliceVecTy.getNumElements() &&
           "Invalid number of slice elements");
    // ... (replace with a constant built from `sliceValues`) ...
  }
};
// Pattern to rewrite an extract_strided_slice(broadcast) as a broadcast of
// the (possibly sliced) broadcast source.
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the destination of the extract
    // match the broadcast source. If so, the original dimensions are
    // untouched and a plain broadcast suffices.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.getSource();
    // If the inner dims don't match, the source itself must be sliced first,
    // unless it is effectively a scalar.
    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
    if (!lowerDimMatch && !isScalarSrc) {
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};

// Rewrites extract_strided_slice(splat(x)) -> splat(x) of the result type.
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto splat = op.getVector().getDefiningOp<SplatOp>();
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
    return success();
  }
};
/// Rewrites a contiguous ExtractStridedSliceOp (unit strides, unit outer
/// dims, full-size inner dims) into vector.extract + shape_cast.
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build: walk the
    // dims from innermost out and stop at the first non-full-size slice dim.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, this whole
    // extract_strided_slice is the identity and is handled elsewhere.
    if (numOffsets == 0)
      return failure();

    // If not even the innermost dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions; include
    // those dims in the extract instead.
    while (sizes[numOffsets] == 1 &&
           numOffsets < static_cast<int>(sizes.size()) - 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
                                                       extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
      context);
}
/// TransferReadOp builder that sets the padding to zero (variant that takes
/// the permutation map and in_bounds as attributes).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// Variant taking the permutation map and in_bounds as plain values.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        inBoundsAttr);
}

/// Variant that infers the permutation map as the minor identity map.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, Value padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// Variant with zero padding and an inferred minor identity permutation map.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           std::optional<ArrayRef<bool>> inBounds) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, padding, inBounds);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
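// Illustrative sketch (not part of VectorOps.cpp): the property
// verifyPermutationMap enforces, over a simplified encoding in which each
// map result is either an input dim index or -1 for the zero constant. The
// function name is hypothetical.
//
// #include <vector>
//
// static bool isProjectedPermutationSketch(const std::vector<int> &results,
//                                          int numInputs) {
//   std::vector<bool> seen(numInputs, false);
//   for (int r : results) {
//     if (r == -1)          // the zero constant is always allowed.
//       continue;
//     if (r < 0 || r >= numInputs || seen[r])
//       return false;       // not a dim, or a dim used more than once.
//     seen[r] = true;
//   }
//   return true;
// }
//
// // {0, 2, -1} over 3 inputs is a projected permutation; {0, 0} is not.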
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds attribute if all dims are out-of-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}

/// Infers the mask type for a transfer op given its vector type and
/// permutation map: the mask shape is the transposed vector shape with i1
/// elements, obtained through the inverse permutation.
VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
  // ... (scalable flags are permuted the same way) ...
  return VectorType::get(maskShape, i1Type, scalableDims);
}
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  // ... (parse `source [indices], padding`; `hasMask` records whether an
  // optional trailing comma introduces a mask operand) ...
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  // ... (resolve the source, index and padding operands) ...
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults())
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    // The mask type is inferred from the vector type and permutation map,
    // keeping the printed type signature small.
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type: padding type must match it.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

// This builds the maximal mask type from the permutation map.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}
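// Illustrative sketch (not part of VectorOps.cpp): the arithmetic behind
// isInBounds. A transfer along a static dim is in-bounds iff the constant
// start index plus the vector length does not run past the dim size. The
// function name is hypothetical.
//
// #include <cstdint>
// #include <optional>
//
// static bool transferDimInBoundsSketch(std::optional<int64_t> constIndex,
//                                       int64_t sourceSize,
//                                       int64_t vectorSize) {
//   if (!constIndex) // dynamic start index: conservatively out-of-bounds.
//     return false;
//   return *constIndex + vectorSize <= sourceSize;
// }
//
// // Reading vector<4> at index 4 from a size-8 dim: 4 + 4 <= 8, in-bounds;
// // at index 5: 5 + 4 > 8, out-of-bounds.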
/// Upgrades the in_bounds attribute of a transfer op: every dimension that
/// can be statically proven in-bounds is marked as such.
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Indices of non-broadcast dims, used when analysing broadcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether the access is in fact
    // in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }
    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are in-bounds, then
  // all broadcast dims are in-bounds as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build the bool array attribute.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}

/// Drops the mask of a transfer op when it is statically all-true.
template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();
  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();
  op.getMaskMutable().clear();
  return success();
}
/// Folds a read-after-write on tensors: a transfer_read of a tensor produced
/// by a matching transfer_write yields the written vector directly.
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  // ... (in_bounds upgrade, full-mask elision and memref.cast folding) ...
  return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
/// Folds a transfer_read of a tensor produced by a matching transfer_write
/// into a broadcast (plus transpose) of the written vector.
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.hasOutOfBoundsDim() ||
        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
      return failure();
    auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // TODO: Support cases where a dim is explicitly written but implicitly
    // read (i.e., a unit dim that is rank-reduced).
    if (readOp.getIndices() != defWrite.getIndices() ||
        readOp.getMask() != defWrite.getMask())
      return failure();
    Value vec = defWrite.getVector();
    // ... (compose the read and write permutation maps; broadcast dims must
    // come first in the destination) ...
    SmallVector<int64_t> broadcastShape(destShape.size());
    SmallVector<bool> broadcastScalableFlags(destShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType(),
        broadcastScalableFlags);
    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
    // ... (transpose into the final permutation and replace the read) ...
    return success();
  }
};

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
/// TransferWriteOp builder with a mask and attributes: the result is the new
/// tensor for tensor destinations, and no result for memref destinations.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// Variant without a mask.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// Variant taking the permutation map and in_bounds as plain values.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// Variant that infers the permutation map as the minor identity map.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  // ... (parse `vector, dest [indices]` and an optional `, mask`) ...
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  auto inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  // ... (resolve the vector, dest and index operands) ...
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults())
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}

LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
/// Folds a full-tensor read followed by a write of the same value back into
/// the original init tensor.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getSource().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero, bail.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getSource());
  return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getSource() == write.getSource() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}

/// Folds a write-after-read into the read's source tensor when the write
/// stores back exactly the value that was read.
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getSource());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  if (succeeded(foldTransferFullMask(*this)))
    return success();
  return memref::foldMemRefCast(*this);
}
std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}

/// Removes dead write-after-write chains on tensors: a transfer_write that is
/// fully overwritten by a later transfer_write at the same indices is
/// redundant, so the later write can use the earlier write's source directly.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite =
        writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getSourceMutable().assign(defWrite.getSource());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write has no other use, we can safely look at the
      // store before it in the chain.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
/// Swaps a tensor.insert_slice(tensor.extract_slice(transfer_write)) chain so
/// that the transfer_write operates directly on the smaller extracted tensor.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if the 'tensor.extract_slice' is rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if the 'tensor.extract_slice' has non-zero offsets.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if the 'vector.transfer_write' has non-zero indices.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TranferWriteOp has non-zero offset");
    }

    // Fail if 'tensor.insert_slice' and 'tensor.extract_slice' ranks differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the 'vector.transfer_write' may not write the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the order of the extract and the write.
    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
        extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
        insertOp.getMixedStrides());
    auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
        transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
        transferOp.getIndices(), transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(SmallVector<bool>(vectorShape.size(), true)));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If rank == 0 or the vector has a single (non-scalable) element, the
  // layout is irrelevant.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}

LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// An all-true mask degenerates into a plain load; an all-false mask yields
/// the pass-through value.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}

/// An all-true mask degenerates into a plain store; an all-false mask makes
/// the store a no-op.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// The mask has the same shape (and scalability) as the index vector, with
/// i1 elements.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// An all-false mask yields the pass-through value; there is no unmasked
/// equivalent for an all-true gather.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder>(context);
}
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

/// An all-false mask makes the scatter a no-op; there is no unmasked
/// equivalent for an all-true scatter.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder>(context);
}
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// An all-true mask degenerates into a plain load; an all-false mask yields
/// the pass-through value.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

/// An all-true mask degenerates into a plain store; an all-false mask makes
/// the compressing store a no-op.
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
/// Returns true if each element of 'a' is equal to the product of a
/// contiguous sequence of the elements of 'b'. Returns false otherwise.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  auto isOne = [](int64_t v) { return v == 1; };

  // Special case for n-D to 0-D shape cast: 'b' must be all ones to be shape
  // casted to a 0-D vector.
  if (rankA == 0 && llvm::all_of(b, isOne))
    return true;

  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  return i == rankA && j == rankB;
}
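// Illustrative sketch (not part of VectorOps.cpp): the grouping rule behind
// isValidShapeCast, restated standalone. Every dim of the smaller-rank shape
// must equal the product of a contiguous run of dims of the larger-rank
// shape, consumed left to right. Trailing unit dims, which the real helper
// special-cases, are ignored here for brevity; the function name is
// hypothetical.
//
// #include <cstdint>
// #include <vector>
//
// static bool validShapeCastSketch(const std::vector<int64_t> &a,
//                                  const std::vector<int64_t> &b) {
//   size_t i = 0, j = 0;
//   while (i < a.size() && j < b.size()) {
//     int64_t dimB = 1;
//     while (dimB < a[i] && j < b.size())
//       dimB *= b[j++];
//     if (a[i] != dimB)
//       return false;
//     ++i;
//   }
//   return i == a.size() && j == b.size();
// }
//
// // {4, 6} -> {2, 2, 6} groups as 4 = 2*2 and 6 = 6: valid.
// // {4, 6} -> {2, 3, 4} fails: no contiguous prefix multiplies to exactly 4.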
static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {
  // Check that element type is the same.
  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
    return op->emitOpError("source/result vectors must have same element type");
  auto sourceShape = sourceVectorType.getShape();
  auto resultShape = resultVectorType.getShape();

  // Check that product of source dim sizes matches product of result dim
  // sizes.
  int64_t sourceDimProduct = std::accumulate(
      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
  int64_t resultDimProduct = std::accumulate(
      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
  if (sourceDimProduct != resultDimProduct)
    return op->emitOpError("source/result number of elements must match");

  // Check the expanding/contracting rank cases.
  unsigned sourceRank = sourceVectorType.getRank();
  unsigned resultRank = resultVectorType.getRank();
  if (sourceRank < resultRank) {
    if (!isValidShapeCast(sourceShape, resultShape))
      return op->emitOpError("invalid shape cast");
  } else if (sourceRank > resultRank) {
    if (!isValidShapeCast(resultShape, sourceShape))
      return op->emitOpError("invalid shape cast");
  }

  // Check that the number of scalable dims is preserved.
  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return op->emitOpError("different number of scalable dims at source (")
           << sourceNScalableDims << ") and result (" << resultNScalableDims
           << ")";
  // ... (an additional check involving
  // sourceVectorType.getNumDynamicDims() is elided in this listing) ...
  return success();
}

LogicalResult ShapeCastOp::verify() {
  auto sourceVectorType =
      llvm::dyn_cast_or_null<VectorType>(getSource().getType());
  auto resultVectorType =
      llvm::dyn_cast_or_null<VectorType>(getResult().getType());

  // Check if source/result are of vector type.
  if (sourceVectorType && resultVectorType)
    return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);

  return success();
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  // No-op shape cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling shape casts.
  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Only allow valid transitive folding.
    VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
    VectorType resultType = llvm::cast<VectorType>(getResult().getType());
    if (srcType.getRank() < resultType.getRank()) {
      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
        return {};
    } else if (srcType.getRank() > resultType.getRank()) {
      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
        return {};
    } else {
      return {};
    }

    setOperand(otherOp.getSource());
    return getResult();
  }

  // Canceling broadcast and shape cast ops.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == getType())
      return bcastOp.getSource();
  }

  return {};
}
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto constantOp =
        shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
    if (!constantOp)
      return failure();
    // Only handle splat for now.
    auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // ... (rebuild the splat with the result vector type) ...
  }
};

/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
///   vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;
  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
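// Illustrative sketch (not part of VectorOps.cpp): the shape computation in
// trimTrailingOneDims on plain integer shapes, keeping at least one dim. The
// real helper additionally refuses to drop scalable unit dims; the function
// name is hypothetical.
//
// #include <cstdint>
// #include <vector>
//
// static std::vector<int64_t>
// trimTrailingOnesSketch(std::vector<int64_t> shape) {
//   while (shape.size() > 1 && shape.back() == 1)
//     shape.pop_back();
//   return shape;
// }
//
// // {4, 1, 1} -> {4}; {1, 1} -> {1} (never empties the shape).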
/// Folds qualifying shape_cast(create_mask) into a new create_mask.
///
/// Looks at `vector.shape_cast` ops that only drop trailing unit dimensions
/// and folds the shape cast into the mask-creating op.
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check every mask operand for the dims being dropped: each must be a
      // constant 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }

      auto newMaskOperands = maskOperands.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check the mask dim sizes for the dims being dropped: each must be 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    ArrayRef<int64_t> broadcastSourceShape;
    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
      broadcastSourceShape = srcType.getShape();
    ArrayRef<int64_t> shapeCastTargetShape =
        shapeCastOp.getResultVectorType().getShape();

    // If `broadcastSourceShape` is a suffix of the result shape, a single
    // broadcast to the final shape suffices.
    if (broadcastSourceShape ==
        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, shapeCastOp.getResultVectorType(),
          broadcastOp.getSource());
      return success();
    }

    // Otherwise, if the element counts match, a shape_cast of the broadcast
    // source suffices.
    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
      if (srcType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    return failure();
  }
};

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
              ShapeCastBroadcastFolder>(context);
}
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern across both halves.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a wider integer: replicate the narrow element.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower-width element across the wide word.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(),
                                        IntegerAttr::get(dstElemType, intBits));
        }
      }
    }
  }

  return {};
}
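// Illustrative sketch (not part of VectorOps.cpp): the splat-widening trick
// used by the fold above, on plain uint64_t instead of APInt. Bitcasting a
// splat of narrow elements to a wider element type replicates the narrow bit
// pattern across the wide word. The function name is hypothetical.
//
// #include <cstdint>
//
// static uint64_t replicateSplatBitsSketch(uint64_t narrowBits,
//                                          unsigned srcBitWidth,
//                                          unsigned dstBitWidth) {
//   uint64_t wide = narrowBits; // already zero-extended.
//   for (unsigned i = 0; i < dstBitWidth / srcBitWidth - 1; ++i)
//     wide = (wide << srcBitWidth) | narrowBits;
//   return wide;
// }
//
// // A splat of i8 0xAB viewed as i32 becomes 0xABABABAB.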
/// Returns the memref shape concatenated with the shape of its (optional)
/// vector element type.
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
                              memRefType.getShape().end());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}

LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}

OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto attr =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
    if (attr.isSplat())
      return attr.reshape(getResultVectorType());

  // Eliminate identity transpose ops. This happens when the dimensions of the
  // input vector remain in their original order after the transpose operation.
  ArrayRef<int64_t> perm = getPermutation();
  for (int64_t i = 0, e = perm.size(); i < e; i++) {
    if (perm[i] != i)
      return {};
  }
  return getVector();
}
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};
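// Illustrative sketch (not part of VectorOps.cpp): composing two transpose
// permutations as in TransposeFolder. Applying p1 and then p2 is the single
// permutation r with r[i] = p1[p2[i]]. The function name is hypothetical.
//
// #include <cassert>
// #include <cstdint>
// #include <vector>
//
// static std::vector<int64_t>
// composePermutationsSketch(const std::vector<int64_t> &p1,
//                           const std::vector<int64_t> &p2) {
//   assert(p1.size() == p2.size());
//   std::vector<int64_t> result;
//   result.reserve(p2.size());
//   for (int64_t index : p2)
//     result.push_back(p1[index]);
//   return result;
// }
//
// // transpose [1, 0] followed by transpose [1, 0] composes to the identity
// // [0, 1], so the pair of transposes folds away.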
/// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
      return success();
    }

    return failure();
  }
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
    if (!splatOp)
      return failure();

    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
    return success();
  }
};

/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};
void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
              TransposeFolder, FoldTransposeSplat>(context);
}
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}
LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that the array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all are (the mask region
  // is a conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
/// Folds a create_mask with constant (or vscale-multiple) bounds into a
/// constant_mask.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: 0-D vectors are treated as single-element 1-D masks.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Gather the constant mask bounds. For scalable dims the operand must be
    // a `<multiplier> * vscale` covering the whole dimension.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // A non-negative constant bound on a scalable dim cannot fold (it
        // cannot cover `dim * vscale` lanes in general).
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier =
                     getConstantVscaleMultiplier(dimSize)) {
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dims is zero the mask is all-false: zero out all dims.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // ... (replace with vector.constant_mask built from `constantDims`) ...
    return success();
  }
};
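// Illustrative sketch (not part of VectorOps.cpp): the normalization the
// CreateMaskOp folder applies to constant mask bounds. Each bound is clamped
// to [0, dimSize], and one empty dimension empties the whole mask because
// the mask region is the conjunction of the per-dim intervals. The function
// name is hypothetical.
//
// #include <algorithm>
// #include <cstdint>
// #include <vector>
//
// static std::vector<int64_t>
// normalizeMaskBoundsSketch(std::vector<int64_t> bounds,
//                           const std::vector<int64_t> &shape) {
//   for (size_t i = 0; i < bounds.size(); ++i)
//     bounds[i] = std::clamp<int64_t>(bounds[i], 0, shape[i]);
//   if (std::find(bounds.begin(), bounds.end(), 0) != bounds.end())
//     bounds.assign(bounds.size(), 0);
//   return bounds;
// }
//
// // shape {4, 4}, bounds {9, 3} -> {4, 3}; bounds {2, 0} -> {0, 0}.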
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

/// Builds a vector.mask op, using `maskRegionBuilder` to populate the mask
/// region with the operation to mask.
void MaskOp::build(OpBuilder &builder, OperationState &result, Value mask,
                   Operation *maskableOp,
                   function_ref<void(OpBuilder &, Operation *)>
                       maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(OpBuilder &builder, OperationState &result,
                   TypeRange resultTypes, Value mask, Operation *maskableOp,
                   function_ref<void(OpBuilder &, Operation *)>
                       maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(OpBuilder &builder, OperationState &result,
                   TypeRange resultTypes, Value mask, Value passthru,
                   Operation *maskableOp,
                   function_ref<void(OpBuilder &, Operation *)>
                       maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // ... (parse the mask operand and an optional `, passthru` operand) ...
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();
  // ... (parse the mask region, attribute dict and types; resolve the mask
  // operand) ...
  MaskOp::ensureTerminator(maskRegion, builder, result.location);
  result.types.append(resultTypes);
  if (parsePassthru.succeeded()) {
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }
  return success();
}

void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}

void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // ... (locate the masked operation `maskedOp` and the auto-created
  // terminator `oldYieldOp`; bail out when the region already had one) ...
  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");

  // Empty vector.mask op: nothing to rewire.
  if (maskedOp == oldYieldOp)
    return;

  // Replace the default yield with one forwarding the masked op's results.
  opBuilder.setInsertionPoint(oldYieldOp);
  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
  oldYieldOp->dropAllUses();
  oldYieldOp->erase();
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask: nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResult(0).getType())
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (isEmpty())
    return failure();

  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move the maskable operation outside of the `vector.mask` region.
  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
/// Elides empty vector.mask operations with or without return values.
/// Propagates the values yielded by the vector.yield terminator, if any, or
/// erases the op otherwise.
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
    if (maskingOp.getMaskableOp())
      return failure();

    if (!maskOp.isEmpty())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    if (terminator.getNumOperands() == 0)
      rewriter.eraseOp(maskOp);
    else
      rewriter.replaceOp(maskOp, terminator.getOperands());
    return success();
  }
};

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ElideEmptyMaskOp>(context);
}
// MaskingOpInterface definitions.

/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;
  return &block->front();
}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check the shapes of the initial value and the source.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // SplatElementsAttr::get treats a single value as a splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  llvm::SMLoc inputsOperandsLoc;
  // ... (parse `(laneid)[warp_size]`, an optional `args(... : ...)` list and
  // an optional result type list) ...
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();
  // ... (parse the warp region) ...
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
  return success();
}

void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // ... (the single warp region branches back to the parent op) ...
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  // ... (add the laneId and args operands, the warp_size attribute and the
  // result types) ...
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}
/// Checks if the distributed vector type is consistent with the expanded type
/// and the distributed size (warp size).
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (std::accumulate(scales.begin(), scales.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  auto yield =
      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as the masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  // ... (moving `maskableOp` into the region and creating the vector.yield
  // terminator is elided in this extraction) ...
}

/// Creates a vector.mask operation around a maskable operation.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  // ... (the no-mask early exit is elided) ...
  return builder.create<MaskOp>(maskableOp->getLoc(),
                                maskableOp->getResultTypes(), mask, maskableOp,
                                createMaskOpRegion, passthru);
}

/// Creates a vector select that picks lanes from `newValue` or `passthru`
/// based on `mask`.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  // ... (the no-mask early exit is elided) ...
  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
                                         mask, newValue, passthru);
}
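// Resulting IR shape (illustrative): masking a transfer_read with a passthru
// value produces
//
//   %0 = vector.mask %mask, %passthru {
//     vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<8xf32>
//   } : vector<8xi1> -> vector<8xf32>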
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
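For intuition (an illustrative note, not part of the generated index): two transfer ops on the same memref reading vector<4xf32> at constant indices 0 and 8 are provably disjoint, because the index distance (8) is at least the vector dimension size (4).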
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Return a fused vector::ContractionOp which represents a pattern such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.