#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
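// Helpers that classify a mask value as all-true, all-false, or unknown by
// inspecting constant mask definitions (dense constants, vector.constant_mask,
// and vector.create_mask).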
  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
    for (bool b : denseElts.getValues<bool>())
      if (b && val >= 0)
        val = 1; // all true so far
      else if (!b && val <= 0)
    auto shape = m.getType().getShape();
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t i =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  builder.create<vector::YieldOp>(loc);
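// Each combining kind is only supported on matching element types: ADD/MUL on
// any int/index/float type, the integer min/max and bitwise kinds on
// ints/indices, and the floating-point min/max kinds on float types.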
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
                                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
                                                    VectorType vectorType) {
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
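// A write of a constant splat followed by a masked read padded with the same
// splat value can be treated as reading that same value; the helpers below use
// this when detecting redundant read-after-write and write-after-write pairs.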
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // The splat value could be the same if the read is masked (and padded with
  // the splat value) and the write is unmasked or uses the same mask.
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for a constant splat as the source of the write.
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat())
    return false;
static bool checkSameValueRAW(vector::TransferWriteOp defWrite,
                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}
static bool checkSameValueWAW(vector::TransferWriteOp write,
                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
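// Two transfers are provably disjoint if, along some dimension, their indices
// are known to differ by at least the vector size in that dimension; leading
// shaped (non-vector) dimensions only need to differ at all.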
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfers of the same vector type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, the slices are disjoint as soon as the
      // indices are proven different.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;
        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For trailing dimensions, the accessed intervals must not overlap:
      // the indices must differ by at least the vector size.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;
        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();
    // Carry the overflow into the next (more significant) dimension.
    posInDim = offsetInDim;
  }
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) -> Value {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return builder
                          .create<arith::ConstantIndexOp>(
                              loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();
                    return cast<Value>(foldResult);
                  });
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
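// Dialect initialization: registers attributes, operations, the inliner
// interface, and interfaces promised to other components (bufferization,
// subset ops, LLVM conversion).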
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}

Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (isa<ub::PoisonAttrInterface>(value))
    return builder.create<ub::PoisonOp>(loc, type,
                                        cast<ub::PoisonAttrInterface>(value));
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
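// Folds a multi_dim_reduction whose reduced dimensions are all of unit size:
// the reduction degenerates into a shape_cast (vector result) or an extract
// (scalar result) combined with the accumulator.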
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      }
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // All reduced dimensions have size 1, so an extraction suffices.
      if (mask)
        mask = rewriter.create<vector::ExtractOp>(loc, mask);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind,
        vector,
LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
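// Maps an arith.atomicrmw kind to the corresponding vector.reduction
// combining kind over the same element type.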
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = rewriter.create<ExtractOp>(loc, mask);
    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  DictionaryAttr dictAttr;

  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }

  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(result.getContext(),
                                             ContractionOp::getDefaultKind()));
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  std::array<VectorType, 2> maskTypes = {
      VectorType::get(lhsType.getShape(), parser.getBuilder().getI1Type()),
      VectorType::get(rhsType.getShape(), parser.getBuilder().getI1Type())};
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into their string representation.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) !=
            rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer the expected result vector type from the lhs/rhs maps and types.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}
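// Collects the (lhsDim, rhsDim) pairs whose iterator matches the requested
// kind: reduction iterators yield contracting dims, parallel iterators that
// appear on both sides yield batch dims.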
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
  setResultRanges(getResult(), argRanges.front());

  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here now.
  if (!adaptor.getPosition())
    return {};

  // Fold extract_element (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extract_element (broadcast X) -> X when X is a scalar.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
      return broadcast.getSource();

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!pos || !src)
    return {};

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}
  return index == poisonValue || (index >= 0 && index < maxIndex);
  setResultRanges(getResult(), argRanges.front());

  auto vectorTy = cast<VectorType>(source.getType());

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {

  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
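// Walks a chain of vector.insert and vector.transpose ops feeding a
// vector.extract, trying to replace the extract with a previously inserted
// value or to fold it in place with an updated position.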
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  Value fold();

private:
  /// Returns true if the positions in `a` form a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Returns true if `a` and `b` agree on every entry where both values are
  /// non-negative (negative entries are sentinels).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  void updateStateForNextIteration(Value v);

  LogicalResult handleTransposeOp();

  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;
};

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  res = nextInsertOp.getValueToStore();
  // Early-exit fold.
  return success(canFold());
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  res = nextInsertOp.getValueToStore();
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  if (extractOp.hasDynamicPosition())
    return Value();

  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
    return Value();
  // Otherwise, fold in place.
  extractOp.setStaticPosition(
      ArrayRef(extractPosition).take_front(extractOp.getNumIndices()));
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();
Value ExtractFromInsertTransposeChainState::fold() {
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1: a transpose just composes the permutation map; iterate.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the insert position matches exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: the insert position is a prefix of the extract position.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Otherwise, keep walking the insert chain.
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;
  };
  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();

  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                       : 0;
  };

  // If splat or broadcast from a scalar, just return the source scalar.
  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
    return source;

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank > broadcastSrcRank)
    return Value();
  // Check that the dimensions of the result haven't been broadcasted.
  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();

  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();

  // Detect all the positions that come from "dim-1" broadcasting; set the
  // matching extract position to `0` when extracting from the source operand.
  llvm::SetVector<int64_t> broadcastedUnitDims =
      broadcastOp.computeBroadcastedUnitDims();
  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))
      extractPos[i] = 0;
  // `rankDiff` leading dimensions correspond to new broadcasted dims; drop
  // the matching extract positions when extracting from the source operand.
  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // TODO: 0-D or multi-dimensional vectors not supported yet.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Find the shuffled vector to extract from based on the shuffle index.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from the lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shape_cast source.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source, then
  // use them to linearize the extract position.
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Extract the strides of the shape_cast source and delinearize the
  // position with them.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The dimensions of the result must be untouched by the strided slice.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    auto extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the extract offset lies in the inserted interval.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the inserted vector.
    if (!disjoint) {
      // If any inner dimension is only partially inserted, there is only a
      // partial overlap and we cannot fold.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getVectorMutable().assign(insertOp.getValueToStore());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // Disjoint chunks: keep looking through the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  // Look for extract(from_elements).
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute the flattened/linearized index and fold to the matching operand.
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
  ArrayRef<int64_t> vectorShape;
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();
  else
    vectorShape = op.getDestVectorType().getShape();

  // If there are no dynamic operands, there is nothing to fold.
  if (!dynamicPosition.size())
    return {};

  // `index` is used to iterate over the `dynamicPosition`.
  unsigned index = 0;

  // If `opChange` is true, the op is updated in place.
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (!ShapedType::isDynamic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
      // Do not fold if the value is out of bounds (-1 signifies a poison
      // index rather than an out-of-bounds one).
      if (value >= -1 && value < vectorShape[i]) {
        staticPosition[i] = attr.getInt();
        opChange = true;
        continue;
      }
    }
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    return op.getResult();
  }
  return {};
}
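// Poison folding: an extract/insert whose static position contains a poison
// (-1) index, or an extract whose source is poison, folds to a poison
// attribute.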
                                                int64_t poisonVal) {
  if (!is_contained(staticPos, poisonVal))
    return {};
  return ub::PoisonAttr::get(context);
}

static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
    return srcAttr;
  return {};
}

  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
  if (!denseAttr)
    return {};
  if (denseAttr.isSplat()) {
    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    return newAttr;
  }

  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())
    return {};
  if (extractOp.hasDynamicPosition())
    return {};

  copy(extractOp.getStaticPosition(), completePositions.begin());
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  Attribute newAttr;
  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
    SmallVector<Attribute> elementValues(
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
  } else {
    newAttr = *denseValuesBegin;
  }
  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
    return getVector();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;

  return inplaceFolded;
    Operation *defOp = extractOp.getVector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    if (extractOp.getType() == source.getType())
      return failure();
    auto getRank = [](Type type) {
      return llvm::isa<VectorType>(type)
                 ? llvm::cast<VectorType>(type).getRank()
                 : 0;
    };
    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    // Only consider sources with rank no greater than the extract result;
    // the other cases are handled in the folding patterns.
    if (extractResultRank < broadcastSrcRank)
      return failure();
    // For scalar results, the broadcast is simply eliminated.
    if (extractResultRank == 0)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        extractOp, extractOp.getType(), source);
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());

    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();

    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim are unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is past the bound from the create_mask, the
        // extracted mask is all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dim is not all-true, and since this is a dynamic index we
        // don't know whether the extraction lands in the true or false
        // region.
        containsUnknownDims = true;
      }
    }

    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();
  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();
  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();
  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of the first extracted element and flatten it.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);
  int flatIndex = 0;
  int stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  // Replace the op with a smaller from_elements op.
  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  if (operands.empty())
    return false;

  return llvm::all_of(operands, [&](Value operand) {
    Operation *currentDef = operand.getDefiningOp();
    return currentDef == defOp;
  });
static LogicalResult
foldToElementsFromElements(ToElementsOp toElementsOp,
                           SmallVectorImpl<OpFoldResult> &results) {
  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();

  llvm::append_range(results, fromElementsOp.getElements());
  return success();
}
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
  OperandRange fromElemsOperands = fromElementsOp.getElements();
  if (fromElemsOperands.empty())
    return {};

  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
  if (!toElementsOp)
    return {};

  Value toElementsInput = toElementsOp.getSource();
  if (fromElementsOp.getType() == toElementsInput.getType() &&
      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
    return toElementsInput;
  }
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {

  if (!llvm::all_equal(fromElementsOp.getElements()))
    return failure();
  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp,
                                       fromElementsOp.getType(),
                                       fromElementsOp.getElements().front());
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {

    // Handled by `rewriteFromElementsAsSplat`.
    if (fromElements.getType().getNumElements() == 1)
      return failure();

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {

      // Check that the element is from a vector.extract operation.
      auto extractOp =
          dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
      if (!extractOp)
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");

      // Check that all elements are extracted from the same vector.
      if (insertIndex == 0) {
        source = extractOp.getVector();
      } else if (extractOp.getVector() != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }

      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      // For the first element, check that the extracted elements form a
      // suffix of the source and compute the combined position.
      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        int64_t numSuffixElms = 1;
        int64_t index = rank;
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);
          --index;
        }
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements do not form a suffix of source");
        }
        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);

      } else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      }
      increment(expectedPosition, source.getType().getShape());
    }

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        fromElements, fromElements.getType(), extracted);
    return success();
  }
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    indices[dim] += 1;
    if (indices[dim] < shape[dim])
      break;
    indices[dim] = 0;
  }
  setResultRanges(getResult(), argRanges.front());

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness check.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());

  // Step 2. If scalar -> dstShape broadcast, just do it.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 3. Since vector.broadcast only allows creating leading dims, a
  // vector -> dstShape broadcast may additionally require a transpose.
  broadcastShape.reserve(dstShape.size());

  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // 3.a. A broadcasted dim is brought to the head of broadcastShape and
      // permuted back into position `i`.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 3.b. A non-broadcasted dim comes from the src shape and is permuted
      // into position `i`; the whole srcShape is appended afterwards.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  // 3.c. Append the srcShape.
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no "dim-1" broadcasts.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");

  // Step 4. If any dimension needs to be permuted, emit a vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {

    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // The source must have an exact match or a singleton value for all
  // trailing dimensions; leading dimensions are simply duplicated.
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine; everything else is rejected when mixing
        // fixed-width and scalable dims.
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
                                             broadcastOp.getResultVectorType(),
                                             srcBroadcast.getSource());

  results.add<BroadcastFolder>(context);
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // Construct the resulting type: the leading dimension matches the mask
  // length, all trailing dimensions match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // A 0-D shuffle has a 1-D result type; this is a canonicalization into a
  // vector.broadcast, not a folding.
  if (v1Type.getRank() == 0)
    return {};

  auto mask = getMask();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    v2Elements =
        to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    v1Elements =
        to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
  }
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    if (v1VectorType.getRank() > 0)
      return failure();
    if (mask.size() != 1)
      return failure();
    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();

    if (!v1Splat || !v2Splat)
      return failure();

    if (v1Splat.getInput() != v2Splat.getInput())
      return failure();
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent a scalable interleave");

    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp types don't match an interleave");
    }

    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
        return rewriter.notifyMatchFailure(op,
                                           "ShuffleOp mask not interleaving");
    }
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);

  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
  build(builder, result, source, dest, {});

LogicalResult vector::InsertElementOp::verify() {
  auto dstVectorType = getDestVectorType();
  if (dstVectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (dstVectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}
OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here.
  if (!adaptor.getPosition())
    return {};

  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!src || !dst || !pos)
    return {};

  auto dstElements = dst.getValues<Attribute>();
  SmallVector<Attribute> results(dstElements);

  uint64_t posIdx = pos.getInt();
  if (posIdx >= results.size())
    return {};
  results[posIdx] = src;
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));

  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {

  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);

  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding "
                  "dest vector dimension";
      }
    }
  }
        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
    auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();

    if (!srcSplat || !dstSplat)
      return failure();

    if (srcSplat.getInput() != dstSplat.getInput())
      return failure();
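// Constant-folds vector.insert into a dense constant destination by splicing
// the inserted values at the linearized insert position, guarded by a size
// threshold to avoid materializing huge constants.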
                                  int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst)
    return {};

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  SmallVector<int64_t> completePositions(destTy.getRank(), 0);
  copy(insertOp.getStaticPosition(), completePositions.begin());
  int64_t insertBeginPosition =
      linearize(completePositions, computeStrides(destTy.getShape()));

  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
    if (intAttr.getType() != expectedType)
      return IntegerAttr::get(expectedType, intAttr.getInt());
  }

  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(convertIfNeeded(value));
  }

  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
}

OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold "vector.insert %v, %dest []" (no position) into %v.
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }

  return inplaceFolded;
}
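// Shared verification helpers for strided-slice ops: each checks that an
// integer array attribute (offsets/sizes/strides) has admissible rank and
// that every element (or pairwise sum) stays within the corresponding vector
// dimension.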
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }
  return success();
}
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(x), SplatOp(x)) to the
/// destination SplatOp(x).
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto srcSplatOp =
        insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
    auto destSplatOp =
        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();

    if (!srcSplatOp || !destSplatOp)
      return failure();

    if (srcSplatOp.getInput() != destSplatOp.getInput())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
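// Illustrative IR (assumed syntax): inserting a splat of %x into another
// splat of %x leaves the destination unchanged, so
//
//   %v = vector.splat %x : vector<2x2xf32>
//   %d = vector.splat %x : vector<4x4xf32>
//   %r = vector.insert_strided_slice %v, %d
//          {offsets = [0, 0], strides = [1, 1]}
//          : vector<2x2xf32> into vector<4x4xf32>
//
// folds to %d.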
/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getValueToStore()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check if have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
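// Illustrative IR (assumed syntax): re-inserting a slice that was just
// extracted from the same destination, at the same offsets and strides, is a
// no-op:
//
//   %s = vector.extract_strided_slice %d
//          {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
//          : vector<4x4xf32> to vector<2x4xf32>
//   %r = vector.insert_strided_slice %s, %d
//          {offsets = [1, 0], strides = [1, 1]}
//          : vector<2x4xf32> into vector<4x4xf32>
//
// folds to %d.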
/// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
/// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source and destination vector constants have a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many scalar constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    Value sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
      return failure();

    // TODO: Support non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Calculate the destination element indices by enumerating all slice
    // positions within the destination and linearizing them.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt++;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, offsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getValueToStore();
  return {};
}
void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  // ...
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand.
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }
  // Set the default combining kind if none was provided.
  result.attributes.append(
      OuterProductOp::getKindAttrName(result.name),
      CombiningKindAttr::get(result.getContext(),
                             OuterProductOp::getDefaultKind()));
  // ... (resolve the two mandatory operands; the optional accumulator is
  //      present only when operandsInfo.size() > 2)
  return success();
}
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what is currently supported for scalable
      // vectors.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
// Inference works as follows:
//   1. Add 'sizes' from the prefix of dims for which offsets/sizes/strides
//      are specified.
//   2. Add sizes from 'vectorType' for the remaining dims.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}
  // ...
  result.addTypes(inferStridedSliceOpResultType(vectorType, offsetsAttr,
                                                sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  // ...
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  // ... (range checks on offsets, sizes and strides via the helpers above,
  //      including isSumOfIntegerArrayAttrConfinedToShape(*this, offsets,
  //      sizes, shape, offName, sizesName, ...))
  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }
  return success();
}
// Replace an ExtractStridedSliceOp with the source of an InsertStridedSliceOp
// it extracts from, when the extracted chunk is entirely contained in the
// inserted chunk.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // Exit early if the extract rank exceeds the insert rank.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the start of the extract offset is in the inserted interval.
      if (start <= offset && offset < end) {
        // If the extract interval overlaps but is not fully included we have
        // a partial overlap.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
      } else {
        disjoint = true;
        break;
      }
    }
    // The extracted chunk is fully contained in the inserted chunk.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getValueToStore());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the extracted chunk is disjoint from the inserted chunk, keep looking
    // up the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return failure();
}
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
static OpFoldResult
foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
                                        Attribute foldInput) {
  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
  if (!dense)
    return {};

  // TODO: Handle non-unit strides when they become available.
  if (op.hasNonUnitStrides())
    return {};

  VectorType sourceVecTy = op.getSourceVectorType();
  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceVecTy.getShape());

  VectorType sliceVecTy = op.getType();
  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
  int64_t rank = sliceVecTy.getRank();

  // Expand offsets to match the vector rank.
  SmallVector<int64_t, 4> offsets(rank, 0);
  copy(getI64SubArray(op.getOffsets()), offsets.begin());

  // Calculate the slice elements by enumerating all slice positions and
  // linearizing them.
  const auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(sliceVecTy.getNumElements());
  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
  do {
    int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
  } while (succeeded(
      incSlicePosition(currSlicePosition, sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
             sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  return DenseElementsAttr::get(sliceVecTy, sliceValues);
}
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getVector();
  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
}
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // ... (gather maskDimSizes, sliceOffsets and sliceSizes)
    // Compute the size of each slice mask dimension.
    SmallVector<int64_t> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any of the new mask dim sizes is zero, the mask is all-false; this is
    // represented with all dim sizes zero.
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
    // Replace 'extractStridedSliceOp' with a ConstantMaskOp covering the
    // sliced mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
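// Worked example of the slice-size arithmetic above (assumed syntax): for a
// mask dim size of 3, a slice offset of 2 and a slice size of 2,
// sliceMaskDimSize = max(0, min(2 + 2, 3) - 2) = 1, so
//
//   %m = vector.constant_mask [3] : vector<8xi1>
//   %s = vector.extract_strided_slice %m
//          {offsets = [2], sizes = [2], strides = [1]}
//          : vector<8xi1> to vector<2xi1>
//
// folds to vector.constant_mask [1] : vector<2xi1>.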
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the source of the broadcast are the
    // same as the destination of the extract. If so we can just use a
    // broadcast, as the original dimensions are untouched.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.getSource();
    // If the inner dimensions don't match, we need to extract from the source
    // of the original broadcast, unless the source is effectively a scalar.
    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
    if (!lowerDimMatch && !isScalarSrc) {
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};
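// Illustrative IR (assumed syntax): when the slice only restricts the
// broadcast dimensions, the extract can be dropped entirely:
//
//   %b = vector.broadcast %v : vector<4xf32> to vector<8x4xf32>
//   %e = vector.extract_strided_slice %b
//          {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]}
//          : vector<8x4xf32> to vector<2x4xf32>
//
// becomes vector.broadcast %v : vector<4xf32> to vector<2x4xf32>.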
/// Rewrite extract_strided_slice(splat(x)) as splat(x) of the result type.
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto splat = op.getVector().getDefiningOp<SplatOp>();
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
    return success();
  }
};
/// Pattern to rewrite contiguous ExtractStridedSliceOps into a combination of
/// ExtractOp and ShapeCastOp, when the leading sliced dimensions have unit
/// size and the trailing dimensions are kept whole.
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build. Walk the
    // dimensions from innermost out, stopping at the first dimension whose
    // slice size does not cover the whole source dimension.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the innermost dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
                                                       extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};
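// Illustrative IR (assumed syntax): a unit-size slice of the leading dimension
// that keeps the trailing dimension whole is an element extract plus a
// shape_cast:
//
//   %s = vector.extract_strided_slice %v
//          {offsets = [2, 0], sizes = [1, 8], strides = [1, 1]}
//          : vector<4x8xf32> to vector<1x8xf32>
//
// becomes:
//
//   %e = vector.extract %v[2] : vector<8xf32> from vector<4x8xf32>
//   %s = vector.shape_cast %e : vector<8xf32> to vector<1x8xf32>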
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
      context);
}
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 2. Builder that sets padding to zero and an empty mask (variant without
/// attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        inBoundsAttr);
}

/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, Value padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           std::optional<ArrayRef<bool>> inBounds) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, padding, inBounds);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds attribute if all dims are statically in-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  // ...
  if (getMask())
    p << ", " << getMask();
  // ...
}

/// Infers the mask type for a transfer op given its vector type and
/// permutation map. The mask in a transfer op applies to the tensor/buffer
/// side, so its shape is the vector shape *before* permutation/broadcasting.
static VectorType inferTransferOpMaskType(VectorType vecType,
                                          AffineMap permMap) {
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
  // 0-D masks are not supported yet; use a vector<1xi1> instead.
  if (maskShape.empty())
    maskShape.push_back(1);
  // ...
}
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  // ...
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  // ...
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return success();
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type.
    // Check that 'sourceVectorElementType' and 'paddingType' types match.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");
    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

VectorType TransferReadOp::getVectorType() {
  return cast<VectorType>(getVector().getType());
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
  return cstOp.value() + vectorSize <= sourceSize;
}
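// Worked example for the check above (hypothetical values): transferring a
// vector dimension of size 4 at constant index 2 of a static source dimension
// of size 8 is in-bounds because 2 + 4 <= 8; at index 5 it would be
// out-of-bounds since 5 + 4 > 8.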
template <typename TransferOp>
static bool foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return false;

  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Idxs of non-broadcast dims - used when analysing broadcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }
    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are in-bounds, then
  // all broadcast dims are in-bounds as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return false;
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return true;
}
template <typename TransferOp>
static bool foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return false;

  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return false;

  op.getMaskMutable().clear();
  return true;
}
/// Fold a read-after-write sequence on tensors: a transfer_read of the result
/// of a transfer_write that wrote the same chunk folds to the written vector.
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  // ...
  return OpFoldResult();
}
std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if we need an alias analysis.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail if we need a bounds analysis.
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do a partial fold.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // ...
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();

    Value vec = defWrite.getVector();
    // ... (compute the permutation between write and read maps, then build the
    //      broadcast shape)
    for (const auto &pos : llvm::enumerate(permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    auto broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType(),
        broadcastScalableFlags);
    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
    // ...
  }
};
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
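// Illustrative IR for TransferReadAfterWriteToBroadcast (assumed syntax):
//
//   %t1 = vector.transfer_write %v, %t0[%c0] {in_bounds = [true]}
//           : vector<8xf32>, tensor<8xf32>
//   %r  = vector.transfer_read %t1[%c0], %pad {in_bounds = [true]}
//           : tensor<8xf32>, vector<8xf32>
//
// reads back exactly the chunk just written, so %r can be replaced by %v
// (broadcast/transposed when the read permutation map adds dimensions).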
/// 1. Builder with type inference.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr, Value mask,
                            ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// 2. Builder with type inference that sets an empty mask (variant with
/// attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 3. Builder with type inference that sets an empty mask (variant without
/// attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder with type inference that sets an empty mask and sets permutation
/// map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  // ...
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  // ...
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  // ...
  if (getMask())
    p << ", " << getMask();
  // ...
}
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

Value TransferWriteOp::getVector() { return getOperand(0); }
VectorType TransferWriteOp::getVectorType() {
  return cast<VectorType>(getValueToStore().getType());
}
/// Fold a transfer_write of a full-tensor transfer_read back into the read's
/// source tensor (the write re-materializes the tensor it started from).
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // All indices must be zero.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getBase());
  return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getBase() == write.getBase() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}
/// Fold a write-after-read on the same tensor with matching indices, map and
/// type: the write is a no-op and folds to the read's source tensor.
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getBase());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  // ...
  return failure();
}
std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
/// Fold write-after-write on tensors: when a transfer_write is fully
/// overwritten by a later transfer_write, the earlier write can be bypassed.
struct FoldWaw final : public OpRewritePattern<TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely look
      // at the preceding store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
/// Swap a tensor::ExtractSliceOp in front of a vector::TransferWriteOp when
/// the write covers the whole extracted slice, so the write targets the
/// smaller tensor directly.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
    // rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if the vector::TransferWriteOp has non-zero offset.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TransferWriteOp has non-zero offset");
    }

    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp ranks differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the tensor::ExtractSliceOp in front of the
    // vector::TransferWriteOp.
    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
        extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
        insertOp.getMixedStrides());
    auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
        transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
        transferOp.getIndices(), transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(SmallVector<bool>(vectorShape.size(), true)));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If rank == 0 or the vector holds a single element, the layout is
  // irrelevant and the unit-stride check below does not apply.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}
LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < resVecTy.getRank())
    return emitOpError(
        "destination memref has lower rank than the result vector");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < valueVecTy.getRank())
    return emitOpError(
        "source memref has lower rank than the vector to store");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
namespace {
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
} // namespace

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
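// Illustrative folds performed by MaskedLoadFolder (assumed syntax):
//
//   %r = vector.maskedload %base[%c0], %all_true, %pass
//          : memref<8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
//   -->  %r = vector.load %base[%c0] : memref<8xf32>, vector<8xf32>
//
//   %r = vector.maskedload %base[%c0], %all_false, %pass ...
//   -->  %r = %pass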
LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}
namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
} // namespace

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
/// Cheap check whether the index vector is the sequence 0, 1, .., n - 1, in
/// which case a gather/scatter is equivalent to a contiguous masked
/// load/store.
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent yet
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

/// Fold gathers with consecutive offsets [0, 1, 2, ...] into a contiguous
/// maskedload. Only 1-D fixed vectors are supported for now.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");

    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getIndices(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};
} // namespace

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}
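// Illustrative fold for FoldContiguousGather (assumed syntax): a gather whose
// index vector is the sequence [0, 1, ..., n-1] is a contiguous masked load:
//
//   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
//   %g = vector.gather %base[%c0] [%idx], %mask, %pass
//          : memref<16xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
//            into vector<4xf32>
//   -->  %g = vector.maskedload %base[%c0], %mask, %pass : ...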
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent yet
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

/// Fold scatters with consecutive offsets [0, 1, 2, ...] into a contiguous
/// maskedstore. Only 1-D fixed vectors are supported for now.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
    return success();
  }
};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
LogicalResult ShapeCastOp::verify() {
  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that element type is preserved.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitOpError("has different source and result element types");

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}
/// Return true if `transpose` does not permute a pair of non-unit dims. By
/// `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical; in other words, the transpose
/// is effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current) {
        return false;
      }
      current = p;
    }
  }
  return true;
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x), if the transpose is
  // order preserving.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // Y = shape_cast(broadcast(X)) -> X, if X and Y have same type.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(splat constant) -> splat constant of the result shape.
  if (auto splatAttr =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return splatAttr.reshape(getType());

  // shape_cast(poison) -> poison.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
    return ub::PoisonAttr::get(getContext());
  }

  return {};
}
namespace {

/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
///   vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  // TODO: Add support for 0-D vectors.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
/// Folds qualifying shape_cast(create_mask/constant_mask) into a new mask op
/// when the shape_cast only "drops" trailing unit dimensions whose mask sizes
/// are the constant 1, e.g. (assumed syntax):
///
///   %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
///   %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
///   -->  %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check every dropped mask operand is the constant 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check every dropped mask dim size is 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast, or, when the
/// number of elements is unchanged, -> ShapeCast of the broadcast source.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;
    (void)srcIsScalar;

    // If the broadcast does not change the number of elements, the shape_cast
    // can be applied directly to the broadcast source.
    if (srcVectorType) {
      if (srcVectorType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    // Otherwise, fold to a single broadcast when the source is directly
    // broadcastable to the result type.
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    if (vector::isBroadcastableTo(broadcastOp.getSourceType(), dstVectorType) ==
        BroadcastableToResult::Success) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, dstVectorType, broadcastOp.getSource());
      return success();
    }
    return failure();
  }
};

} // namespace

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results
      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
          context);
}
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower-width element.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(),
                                        IntegerAttr::get(dstElemType, intBits));
        }
      }
    }
  }

  return {};
}
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
                              memRefType.getShape().end());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}
LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity transposes, and more generally any transpose that
  // preserves the shape without permuting elements.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return {};
}
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}
std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
namespace {

// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};
/// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
    if (!splatOp)
      return failure();

    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
    return success();
  }
};
/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};
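// Illustrative IR (assumed syntax): transposing a mask is the same as
// permuting the mask sizes:
//
//   %m = vector.create_mask %a, %b : vector<4x8xi1>
//   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
//   -->  %t = vector.create_mask %b, %a : vector<8x4xi1>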
/// Folds transpose(shape_cast) into a new shape_cast when the transpose is
/// order preserving.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    // We don't need to check isValidShapeCast at this point, because it is
    // guaranteed that merging the transpose into the shape_cast is valid.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};
/// Folds transpose(broadcast(x)) to broadcast(x) when the permutation is local
/// to the dimension groups introduced by the broadcast.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        // Check that the permutation maps the group [low, high) to itself.
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    // ...
    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};
} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeBroadcast>(context);
}
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all should be zero (because
  // the mask region is a conjunction of each mask dimension interval).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
namespace {

/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp when all operands
/// are constants (or, for scalable dims, provably foldable vscale multiples).
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank-zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the constant mask dimension sizes.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // A non-negative constant on a scalable dim cannot be folded (only
        // the all-false and all-true cases are representable).
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dims is zero the mask is all-false; represent that with
    // all-zero dim sizes.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace createMaskOp with a new ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};

} // namespace
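// Illustrative fold (assumed syntax): with all-constant operands,
//
//   %c2 = arith.constant 2 : index
//   %c3 = arith.constant 3 : index
//   %m = vector.create_mask %c2, %c3 : vector<4x4xi1>
//
// becomes vector.constant_mask [2, 3] : vector<4x4xi1>; values are clamped to
// [0, dim], and any zero dim collapses the result to the all-false mask.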
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");
  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // ...
  // Optional passthru operand.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();
  // ...
  MaskOp::ensureTerminator(maskRegion, builder, result.location);
  // ...
  result.types.append(resultTypes);
  // ...
  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          typesLoc, "expects a result if passthru operand is provided");
    // ... (resolve the passthru operand against the result type)
  }
  return success();
}
void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.empty() || region.front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.front();
  if (isa<vector::YieldOp>(block.back()))
    return;

  // 3. For a non-empty `vector.mask` without an explicit terminator: create a
  // default terminator if the number of masked operations is not exactly one
  // (this case will trigger a verification failure).
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // Otherwise, create a terminator that yields the results of the masked
  // operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Result checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();
  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  if (terminator.getNumOperands() == 0) {
    // `vector.mask` has no results: nothing to fold to.
    return success();
  }
  // Fold the mask results to the yielded values.
  llvm::append_range(results, terminator.getOperands());
  return success();
}
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();
  // An all-true mask is a no-op: hoist the masked operation out of the region.
  // ...
  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());
  llvm::append_range(results, maskableOp->getResults());
  return success();
}
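// E.g. (illustrative): when the mask is known all-true, such as
// vector.constant_mask [8] : vector<8xi1>, the masked operation is simply
// unmasked and the enclosing vector.mask disappears.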
LogicalResult matchAndRewrite(MaskOp maskOp,
                              PatternRewriter &rewriter) const override {
  if (!maskOp.isEmpty())
    return failure();
  if (!maskOp.hasPassthru())
    return failure();
  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  assert(terminator.getNumOperands() == 1 &&
         "expected one result when passthru is provided");
  rewriter.replaceOpWithNewOp<arith::SelectOp>(
      maskOp, maskOp.getResultTypes(), maskOp.getMask(),
      terminator.getOperand(0), maskOp.getPassthru());
  return success();
}
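// Illustrative rewrite (names made up): an empty mask with a passthru
//
//   %0 = vector.mask %m, %pt { vector.yield %v : vector<8xf32> }
//          : vector<8xi1> -> vector<8xf32>
//
// becomes:
//
//   %0 = arith.select %m, %v, %pt : vector<8xi1>, vector<8xf32>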
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;
  return &block->front();
}

bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
// populateVectorToVectorCanonicalizationPatterns (fragment):
patterns
    .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
         ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
         StridedSliceConstantMaskFolder, TransposeFolder>(
        patterns.getContext(), benefit);
// SplatOp::fold (fragment): only constant integer/float operands fold.
auto constOperand = adaptor.getInput();
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
  return {};
// ...
setResultRanges(getResult(), argRanges.front());
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
// createMaskOpRegion (fragment):
assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
// ...
// maskOperation (fragment): wrap the maskable operation in a vector.mask.
return builder.create<MaskOp>(maskableOp->getLoc(),
                              maskableOp->getResultTypes(), mask, maskableOp,
                              createMaskOpRegion);
// ...
// selectPassthru (fragment):
return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
                                       mask, newValue, passthru);
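// Illustrative result of selectPassthru (names made up): with a mask present
// it materializes
//
//   %r = arith.select %m, %new, %passthru : vector<8xi1>, vector<8xf32>
//
// and with a null mask it simply returns the new value unchanged.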
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"