#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
    ArrayAttr masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();

    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  builder.create<vector::YieldOp>(loc);
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
                                          VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  if (shapedType.getRank() == 0 &&

  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap();

                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
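// Note (illustrative): these two helpers gate the store-to-load forwarding
// (read-after-write) and dead-store elision (write-after-write) folds on
// transfer ops. They conservatively require matching indices, masks, vector
// types and permutation maps, e.g. a vector.transfer_read is only forwarded
// the value of a preceding vector.transfer_write when both ops access the
// same source with identical indices and map, and neither op is masked or
// out of bounds.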
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  if (transferA.getVectorType() != transferB.getVectorType())
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    if (i < rankOffset) {
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
      if (testDynamicValueUsingBounds) {
        if (succeeded(testEqual) && !testEqual.value())
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
      if (testDynamicValueUsingBounds) {
        if (std::abs(computeDelta.value()) >= vectorDim)

                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getSource() != transferB.getSource())
                                   testDynamicValueUsingBounds);
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    if (posInDim < dimSize + offsetInDim)

    posInDim = offsetInDim;

  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();

      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
        return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();

  llvm::transform(foldResults, std::back_inserter(values),
                    if (auto attr = foldResult.dyn_cast<Attribute>())
                          loc, cast<IntegerAttr>(attr).getInt())
                    return foldResult.get<Value>();
void VectorDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();

  return arith::ConstantOp::materialize(builder, value, type, loc);
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        CombiningKind kind) {
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc,

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
      targetShape.push_back(it.value());
      scalableDims.push_back(sourceScalableDims[it.index()]);
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
                         vecType.getScalableDims());
struct ElideUnitDimsInMultiDimReduction
  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
    if (reductionOp.isMasked()) {
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
      rootOp = reductionOp;

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
        VectorType newMaskType =
            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
        cast, nullptr, mask);
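    // Example (illustrative) of what this pattern enables: when every reduced
    // dimension has size 1, the reduction is a pure shape change, e.g.
    //   %0 = vector.multi_reduction <add>, %v, %acc [1]
    //       : vector<4x1xf32> to vector<4xf32>
    // becomes a vector.shape_cast of %v to vector<4xf32> combined with %acc
    // through the ADD combining kind (or a vector.extract in the scalar case).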
void MultiDimReductionOp::getCanonicalizationPatterns(
  results.add<ElideUnitDimsInMultiDimReduction>(context);

                        CombiningKind kind, Value vector,
                        arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, Value(), fastMathFlags);

                        arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,

  int64_t rank = getSourceVectorType().getRank();
    return emitOpError("unsupported reduction rank: ") << rank;

  Type eltType = getDest().getType();
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
                         vecType.getScalableDims());
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
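// Usage sketch (illustrative): the switch above maps an arith::AtomicRMWKind
// to the corresponding vector.reduction combining kind, so addf/addi produce,
// for example,
//   %r = vector.reduction <add>, %v : vector<4xf32> into f32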
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());

        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    if (maskableOp.isMasked()) {
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
      rootOp = reductionOp;

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)

    Location loc = reductionOp.getLoc();
    if (vectorType.getRank() == 0) {
        mask = rewriter.create<ExtractElementOp>(loc, mask);
      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
        mask = rewriter.create<ExtractOp>(loc, mask, 0);
      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);

    if (Value acc = reductionOp.getAcc())
                                   reductionOp.getFastmathAttr(), mask);
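    // Example (illustrative): a reduction over a single-element vector is just
    // an extract of that element (combined with the accumulator when present),
    // e.g.
    //   %0 = vector.reduction <add>, %v : vector<1xf32> into f32
    // is rewritten to an extract of element 0 of %v, plus an arith.addf with
    // the accumulator if one was provided.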
  results.add<ElideSingleElementReduction>(context);

                      getIndexingMapsAttrName(result.name),
                      getIteratorTypesAttrName(result.name),
        return IteratorTypeAttr::get(builder.getContext(), t);

                          ArrayAttr indexingMaps, ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());

                          ArrayAttr indexingMaps, ArrayAttr iteratorTypes,
                          CombiningKind kind) {

  DictionaryAttr dictAttr;
                                     dictAttr.getValue().end());
  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
    iteratorTypeAttrs.push_back(
                      getKindAttrName(result.name),
                          ContractionOp::getDefaultKind()));
  if (masksInfo.empty())
  if (masksInfo.size() != 2)
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  std::array<VectorType, 2> maskTypes = {
  auto attrNames = getTraitAttrNames();
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
      attrs.emplace_back(getIteratorTypesAttrName(),
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "

                        const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
    expectedResultDims.push_back(lhsType.getDimSize(i));

  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
    expectedResultDims.push_back(rhsType.getDimSize(i));

  if (expectedResultDims.empty()) {
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");

    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
          "expected all dimensions to be either a LHS or a RHS dimension");
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
               [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
           "expected constant extent along all dimensions.");
    auto expectedShape = llvm::to_vector<4>(
          return cast<AffineConstantExpr>(e).getValue();
                          resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
                 "invalid accumulator/result vector shape, expected: ")
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");

  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  unsigned numIterators = getIteratorTypes().getValue().size();
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

    return emitOpError("invalid contracting dimension map");

    return emitOpError("invalid batch dimension map");

                              contractingDimMap, batchDimMap)))

  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
    return emitOpError("unsupported contraction type");
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
        lhsType.getScalableDims()[dimIdx];
        rhsType.getScalableDims()[dimIdx];

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");
                         maskShapeScalableDims);

          getIteratorTypesAttrName(), getKindAttrName()};

static std::vector<std::pair<int64_t, int64_t>>
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
void ContractionOp::getIterationBounds(
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);

void ContractionOp::getIterationIndexMap(
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;

std::vector<std::pair<int64_t, int64_t>>
ContractionOp::getContractingDimMap() {
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  getIterationBounds(shape);
template <typename AddOpType>
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
      return vector::ContractionOp();

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
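    // Example (illustrative) of the rewrite, with AddOpType = arith::AddFOp:
    //   %c = vector.contract ... %a, %b, %zero
    //   %r = arith.addf %c, %other
    // becomes a clone of the vector.contract that uses %other as its
    // accumulator, eliminating the separate add.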
  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());

  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
      return emitOpError("expected position to be empty with 0-D vector");
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
    return emitOpError("expected position for 1-D vector");

OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getPosition())

  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())

  return srcElements[posIdx];
                       Value source, int64_t position) {
  build(builder, result, source, dynamicPos,

ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));

    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
        "expected position attribute of rank no greater than vector rank");
      int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
        return emitOpError("expected position attribute #")
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension";

template <typename IntType>
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())

  if (extractOp.hasDynamicPosition())

  ExtractOp currentOp = extractOp;
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    if (currentOp.hasDynamicPosition())
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  extractOp.setOperand(0, currentOp.getVector());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
class ExtractFromInsertTransposeChainState {
  ExtractFromInsertTransposeChainState(ExtractOp e);

  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());

  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)

  void updateStateForNextIteration(Value v) {

  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
                        extractOp.getStaticPosition().end());
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (extractOp.hasDynamicPosition())
  if (!nextTransposeOp)
      nextTransposeOp.getPermutation(), extractOp.getContext()));

ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
  res = nextInsertOp.getSource();

ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
  res = nextInsertOp.getSource();

Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
  if (extractOp.hasDynamicPosition())
  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
  extractOp.setStaticPosition(
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();

Value ExtractFromInsertTransposeChainState::fold() {
  if (extractOp.hasDynamicPosition())
  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
    if (succeeded(handleInsertOpWithMatchingPos(result)))
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
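// Example (illustrative) of a fold enabled by this state machine:
//   %i = vector.insert %src, %dst [0] : vector<4xf32> into vector<8x4xf32>
//   %e = vector.extract %i[0] : vector<4xf32> from vector<8x4xf32>
// folds directly to %src, while an extract whose position is known to be
// disjoint from the insertion position folds through to %dst instead.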
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;

  if (extractOp.hasDynamicPosition())

  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))

  if (extractOp.getType() == source.getType())
  auto getRank = [](Type type) {
    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()

  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank >= broadcastSrcRank)

  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))

  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();

      broadcastOp.computeBroadcastedUnitDims();
  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))

  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));

  extractOp.setOperand(0, source);
  extractOp.setStaticPosition(extractPos);
  return extractOp.getResult();
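// Example (illustrative) of the extract-of-broadcast fold above:
//   %b = vector.broadcast %src : vector<2x4xf32> to vector<3x2x4xf32>
//   %e = vector.extract %b[1, 0] : vector<4xf32> from vector<3x2x4xf32>
// is folded to extract directly from the broadcast source:
//   %e = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>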
  if (extractOp.hasDynamicPosition())

  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();

  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))

  std::reverse(extractedPos.begin(), extractedPos.end());
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);

  int64_t position = linearize(extractedPos, strides);

  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  std::reverse(newStrides.begin(), newStrides.end());

  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())

  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)

  if (extractStridedSliceOp.hasNonUnitStrides())

      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
    sliceOffsets.pop_back();

  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();

  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())

  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
    bool disjoint = false;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
          (dim < insertRankDiff)
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);

      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
      extractOp.getVectorMutable().assign(insertOp.getSource());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();

  if (getNumIndices() == 0)
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    Operation *defOp = extractOp.getVector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))

    if (extractOp.getType() == source.getType())
    auto getRank = [](Type type) {
      return llvm::isa<VectorType>(type)
                 ? llvm::cast<VectorType>(type).getRank()

    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    if (extractResultRank < broadcastSrcRank)

    if (extractResultRank == 0) {
      assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
        extractOp, extractOp.getType(), source);

class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
    Value sourceVector = extractOp.getVector();
    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
class ExtractOpNonSplatConstantFolder final
    if (extractOp.hasDynamicPosition())

    Value sourceVector = extractOp.getVector();
    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
    if (vecTy.isScalable())

    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
    if (!dense || dense.isSplat())

    copy(extractOp.getStaticPosition(), completePositions.begin());
    int64_t elemBeginPosition =
    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;

    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
      newAttr = *denseValuesBegin;
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)

    auto maskOperands = createMaskOp.getOperands();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
        containsUnknownDims = true;

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
      if (pos != ShapedType::kDynamic) {
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        containsUnknownDims = true;
    } else if (!containsUnknownDims) {
        extractOp, extractedMaskType,
        maskOperands.drop_front(extractOpPos.size()));
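    // Example (illustrative): extracting row 1 from
    //   %m = vector.create_mask %c2, %c3 : vector<4x8xi1>
    // with %c2 a constant 2 yields
    //   %e = vector.create_mask %c3 : vector<8xi1>
    // because position 1 lies within the first mask bound; an out-of-bounds
    // position would instead produce an all-false constant mask.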
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());

  if (sourceType.getNumElements() != targetType.getNumElements())

                                            castOp.getSource());

  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);

  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
      assert(s1 == 1 && "expected dim-1 broadcasting");

  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
Value BroadcastOp::createOrFoldBroadcastOp(
  assert(!dstShape.empty() && "unexpected empty dst shape");

  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
    checkShape.push_back(dstShape[i]);
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "

  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  broadcastShape.reserve(dstShape.size());
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
      permutation[i] = nextSrcShapeDim++;

  llvm::append_range(broadcastShape, srcVectorType.getShape());
                    "unexpected dim-1 broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");

  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
    std::pair<int, int> *mismatchingDims) {
    return BroadcastableToResult::Success;

  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;

  int64_t lead = dstRank - srcRank;
  for (int64_t r = 0; r < srcRank; ++r) {
    int64_t srcDim = srcVectorType.getDimSize(r);
    int64_t dstDim = dstVectorType.getDimSize(lead + r);
    if (srcDim != 1 && srcDim != dstDim) {
      if (mismatchingDims) {
        mismatchingDims->first = srcDim;
        mismatchingDims->second = dstDim;
      return BroadcastableToResult::DimensionMismatch;

  return BroadcastableToResult::Success;
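// Note (illustrative): broadcasting follows trailing-dimension rules, e.g.
// vector<4xf32> is broadcastable to vector<3x4xf32> (leading dims are
// prepended) and vector<1x4xf32> to vector<5x4xf32> (unit dims stretch),
// while vector<2x4xf32> to vector<3x4xf32> reports DimensionMismatch.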
  std::pair<int, int> mismatchingDims;
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch)
    return emitOpError("dimension mismatch (")
           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");

  if (getSourceType() == getResultVectorType())
  if (!adaptor.getSource())
  auto vectorType = getResultVectorType();
  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))

    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
                                               broadcastOp.getResultVectorType(),
                                               srcBroadcast.getSource());

  results.add<BroadcastFolder>(context);
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");

  auto maskAttr = getMask().getValue();
  int64_t maskLength = maskAttr.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");

  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
    auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
    if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
      return emitOpError("mask index #") << (en.index() + 1) << " out of range";
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(

  uint64_t expected = begin;
  return idxArr.size() == width &&
         llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
                      [&expected](auto attr) {
                        return attr.getZExtValue() == expected++;
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  VectorType v1Type = getV1VectorType();
  if (v1Type.getRank() == 0)

  if (!v1Type.isScalable() &&
  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
                                   getV2VectorType().getDimSize(0)))

  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
      llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
  if (lhsType.getRank() != 1)

  int64_t lhsSize = lhsType.getDimSize(0);
  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
    int64_t i = index.getZExtValue();
      results.push_back(rhsElements[i - lhsSize]);
      results.push_back(lhsElements[i]);
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayAttr mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
    if (mask.size() != 1)
    if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)

    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
    if (!v1Splat || !v2Splat)
    if (v1Splat.getInput() != v2Splat.getInput())
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
          op, "ShuffleOp can't represent a scalable interleave");
    if (resultType.getRank() != 1)
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
          op, "ShuffleOp types don't match an interleave");

    ArrayAttr shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
      int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
            "ShuffleOp mask not interleaving");
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(

  build(builder, result, source, dest, {});

  auto dstVectorType = getDestVectorType();
  if (dstVectorType.getRank() == 0) {
      return emitOpError("expected position to be empty with 0-D vector");
  if (dstVectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
    return emitOpError("expected position for 1-D vector");

OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getPosition())
  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!src || !dst || !pos)

  auto dstElements = dst.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= results.size())
  results[posIdx] = src;
                       Value source, Value dest, int64_t position) {
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
  build(builder, result, source, dest, posVals);

  build(builder, result, source, dest, dynamicPos,

  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
        "expected position attribute rank to match the dest vector rank");
    if (auto attr = pos.dyn_cast<Attribute>()) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
        return emitOpError("expected position attribute #")
               << " to be a non-negative integer smaller than the "
                  "dest vector dimension";
    auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
        insertOp, insertOp.getDestVectorType(), insertOp.getSource());

    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
    if (!srcSplat || !dstSplat)
    if (srcSplat.getInput() != dstSplat.getInput())
  static constexpr int64_t vectorSizeFoldThreshold = 256;

    if (op.hasDynamicPosition())

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())

    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())

    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);

    Value sourceValue = op.getSource();
    copy(op.getStaticPosition(), completePositions.begin());
    int64_t insertBeginPosition =

    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
      llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
      insertedValues.push_back(sourceCst);

    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
    copy(insertedValues, allValues.begin() + insertBeginPosition);
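    // Example (illustrative): inserting a constant scalar into a constant
    // vector, e.g.
    //   %v = arith.constant dense<0.0> : vector<4xf32>
    //   %r = vector.insert %one, %v [2] : f32 into vector<4xf32>
    // with %one a constant 1.0 folds to a single dense constant
    // [0.0, 0.0, 1.0, 0.0], eliminating the insert entirely.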
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertOpConstantFolder>(context);

OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
  if (getNumIndices() == 0)

template <typename OpType>
                                             ArrayAttr arrayAttr,
                                             StringRef attrName) {
  if (arrayAttr.size() > shape.size())
           << attrName << " attribute of rank no greater than vector rank";
template <typename OpType>
                                                int64_t max, StringRef attrName,
                                                bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    if (val < min || val >= upper)
      return op.emitOpError("expected ")
             << attrName << " to be confined to [" << min << ", " << upper
             << ")";

template <typename OpType>
    bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (val < min || val >= max)
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
template <typename OpType>
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (val1 + val2 < 0 || val1 + val2 >= max)
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";

  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
          offName, "source vector shape",

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")

    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << (" to match the corresponding base size from the input "
               << sourceSize << (" vs ") << destSize << (")");
class FoldInsertStridedSliceSplat final
        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
    if (!srcSplatOp || !destSplatOp)
    if (srcSplatOp.getInput() != destSplatOp.getInput())
    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());

class FoldInsertStridedSliceOfExtract final
    auto extractStridedSliceOp =
        insertStridedSliceOp.getSource()
            .getDefiningOp<vector::ExtractStridedSliceOp>();
    if (!extractStridedSliceOp)
    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
class InsertStridedSliceConstantFolder final
  static constexpr int64_t vectorSizeFoldThreshold = 256;

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())

    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())

    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);

    if (op.hasNonUnitStrides())

    VectorType sliceVecTy = sourceValue.getType();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();

    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
          currDestPosition.begin() + rankDifference, currDestPosition.end());
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;

void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
  p << " " << getLhs() << ", " << getRhs();
  p << ", " << getAcc();
  p << " : " << getLhs().getType() << ", " << getRhs().getType();

  if (operandsInfo.size() < 2)
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
                            "expected vector type for operand #1");
                                      vRHS.getScalableDims()[0]};
                              vLHS.getElementType(), scalableDimsRes);
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                      OuterProductOp::getKindAttrName(result.name),
                          OuterProductOp::getDefaultKind()));
      (operandsInfo.size() > 2 &&
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
          "expected either both or only #2 operand dim to be scalable");

    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");
    return emitOpError("unsupported outerproduct type");

Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
                         vecType.getScalableDims());
  auto inputVectorType = getInputVectorType();
  auto outputVectorType = getOutputVectorType();
  int64_t inputShapeRank = getNumInputShapeSizes();
  int64_t outputShapeRank = getNumOutputShapeSizes();
  getFixedVectorSizes(fixedVectorSizes);
  int64_t numFixedVectorSizes = fixedVectorSizes.size();

  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
    return emitError("invalid input shape for vector type ") << inputVectorType;

  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
    return emitError("invalid output shape for vector type ")
           << outputVectorType;

  unsigned inputVectorRank = inputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = inputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
      return emitError("fixed vector size must match input vector for dim ")

  unsigned outputVectorRank = outputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = outputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
      return emitError("fixed vector size must match output vector for dim ")

  auto isDefByConstant = [](Value operand) {
  if (llvm::all_of(getInputShape(), isDefByConstant) &&
      llvm::all_of(getOutputShape(), isDefByConstant)) {
    int64_t numInputElements = 1;
    for (auto operand : getInputShape())
    int64_t numOutputElements = 1;
    for (auto operand : getOutputShape())
    if (numInputElements != numOutputElements)
      return emitError("product of input and output shape sizes must match");
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}
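
// Illustrative sketch (not from the dialect sources; the helper name is
// hypothetical): the result shape keeps the `sizes` for the leading (sliced)
// dims and the remaining source dims after that, assuming
// sizes.size() <= srcShape.size(). For example, extracting sizes {2, 2} from
// a vector<4x8x16xf32> yields a vector<2x2x16xf32>.
static SmallVector<int64_t>
stridedSliceResultShape(ArrayRef<int64_t> srcShape, ArrayRef<int64_t> sizes) {
  SmallVector<int64_t> shape(sizes.begin(), sizes.end());
  shape.append(srcShape.begin() + sizes.size(), srcShape.end());
  return shape; // {4, 8, 16} with sizes {2, 2} -> {2, 2, 16}.
}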
3295 offsetsAttr, sizesAttr, stridesAttr));
3296 result.
addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.
name),
3300 result.
addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.
name),
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  // ... (per-attribute range/shape checks elided)
  //                                  shape, offName, sizesName,
  // ...
  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }
3363 auto getElement = [](ArrayAttr array,
int idx) {
3364 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3366 ArrayAttr extractOffsets = op.getOffsets();
3368 ArrayAttr extractSizes = op.getSizes();
3369 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3371 if (op.getSourceVectorType().getRank() !=
3372 insertOp.getSourceVectorType().getRank())
3374 ArrayAttr insertOffsets = insertOp.getOffsets();
3375 ArrayAttr insertStrides = insertOp.getStrides();
3378 if (extractOffsets.size() > insertOffsets.size())
3380 bool patialoverlap =
false;
3381 bool disjoint =
false;
3383 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3384 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
3386 int64_t start = getElement(insertOffsets, dim);
3387 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3388 int64_t offset = getElement(extractOffsets, dim);
3389 int64_t size = getElement(extractSizes, dim);
3391 if (start <= offset && offset < end) {
3394 if (offset + size > end)
3395 patialoverlap =
true;
3396 offsetDiffs.push_back(offset - start);
3403 if (!disjoint && !patialoverlap) {
3413 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3423 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3424 if (getSourceVectorType() == getResult().getType())
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
  // ...
    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // ...
    // Compute the mask extent that survives in each sliced dimension: clamp
    // the mask extent to the slice window and rebase it at the slice offset.
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any mask dim size is zero, they all must be zero (the mask is a
    // conjunction of per-dimension intervals), so reset them all to zero.
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3487 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
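
// Illustrative sketch (not from the dialect sources; the helper name is
// hypothetical): the clamping rule above computes how many mask elements
// survive inside a slice window. For a mask extent of 6 in a dimension, a
// slice at offset 4 with size 4 keeps max(0, min(4 + 4, 6) - 4) = 2 leading
// "true" elements.
static int64_t clampedMaskExtent(int64_t maskDimSize, int64_t sliceOffset,
                                 int64_t sliceSize) {
  return std::max<int64_t>(
      0, std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
}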
3494 class StridedSliceSplatConstantFolder final
3503 Value sourceVector = extractStridedSliceOp.getVector();
3508 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3522 class StridedSliceNonSplatConstantFolder final
3531 Value sourceVector = extractStridedSliceOp.getVector();
3537 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3538 if (!dense || dense.isSplat())
3542 if (extractStridedSliceOp.hasNonUnitStrides())
3545 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.
getType());
3549 VectorType sliceVecTy = extractStridedSliceOp.getType();
3551 int64_t sliceRank = sliceVecTy.getRank();
3563 auto denseValuesBegin = dense.value_begin<
Attribute>();
3565 sliceValues.reserve(sliceVecTy.getNumElements());
3568 int64_t linearizedPosition =
linearize(currSlicePosition, sourceStrides);
3569 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3571 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3575 assert(
static_cast<int64_t
>(sliceValues.size()) ==
3576 sliceVecTy.getNumElements() &&
3577 "Invalid number of slice elements");
3587 class StridedSliceBroadcast final
3594 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3599 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3600 auto dstVecType = llvm::cast<VectorType>(op.getType());
3601 unsigned dstRank = dstVecType.getRank();
3602 unsigned rankDiff = dstRank - srcRank;
3606 bool lowerDimMatch =
true;
3607 for (
unsigned i = 0; i < srcRank; i++) {
3608 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3609 lowerDimMatch =
false;
3618 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3619 if (!lowerDimMatch && !isScalarSrc) {
3620 source = rewriter.
create<ExtractStridedSliceOp>(
3632 class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
3638 auto splat = op.getVector().getDefiningOp<SplatOp>();
3648 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3652 results.
add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3653 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3654 StridedSliceSplat>(context);
3663 VectorType vectorType,
Value source,
3664 ValueRange indices, AffineMapAttr permutationMapAttr,
3665 ArrayAttr inBoundsAttr) {
3666 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3667 Value padding = builder.
create<arith::ConstantOp>(
3669 build(builder, result, vectorType, source, indices, permutationMapAttr,
3670 padding,
Value(), inBoundsAttr);
3675 VectorType vectorType,
Value source,
3679 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3682 build(builder, result, vectorType, source, indices, permutationMapAttr,
3688 VectorType vectorType,
Value source,
3692 llvm::cast<ShapedType>(source.
getType()), vectorType);
3694 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3697 build(builder, result, vectorType, source, indices, permutationMapAttr,
3699 Value(), inBoundsAttr);
3705 VectorType vectorType,
Value source,
3708 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3709 Value padding = builder.
create<arith::ConstantOp>(
3711 build(builder, result, vectorType, source, indices, padding, inBounds);
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the optional in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();
  for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
    if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
        !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
      return op->emitOpError("requires broadcast dimensions to be in-bounds");

  return success();
}
3831 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3832 if (op.getPermutationMap().isMinorIdentity())
3833 elidedAttrs.push_back(op.getPermutationMapAttrName());
3835 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
3836 elidedAttrs.push_back(op.getInBoundsAttrName());
  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
3852 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
3882 if (types.size() != 2)
3883 return parser.
emitError(typesLoc,
"requires two types");
3885 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
3886 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
3887 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
3888 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
3890 return parser.
emitError(typesLoc,
"requires vector type");
3891 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.
name);
3898 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
3906 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
3908 maskInfo.
location,
"does not support masks with vector element type");
3911 "expected the same rank for the vector and the "
3912 "results of the permutation map");
3920 result.
addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
3922 {1, static_cast<int32_t>(indexInfo.size()), 1,
3923 static_cast<int32_t>(hasMask.succeeded())}));
3929 ShapedType shapedType = getShapedType();
3931 VectorType maskType = getMaskType();
3932 auto paddingType = getPadding().getType();
3933 auto permutationMap = getPermutationMap();
3934 VectorType inferredMaskType =
3937 auto sourceElementType = shapedType.getElementType();
3939 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
3940 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
3943 shapedType, vectorType, maskType,
3944 inferredMaskType, permutationMap,
3945 getInBounds() ? *getInBounds() : ArrayAttr())))
3948 if (
auto sourceVectorElementType =
3949 llvm::dyn_cast<VectorType>(sourceElementType)) {
3952 if (sourceVectorElementType != paddingType)
3954 "requires source element type and padding type to match.");
3958 if (!VectorType::isValidElementType(paddingType))
3959 return emitOpError(
"requires valid padding vector elemental type");
3962 if (paddingType != sourceElementType)
3964 "requires formal padding and source of the same elemental type");
3968 [&](Twine t) {
return emitOpError(t); });
3975 Type TransferReadOp::getExpectedMaskType() {
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}
3996 template <
typename TransferOp>
4000 if (op.getTransferRank() == 0)
4002 AffineMap permutationMap = op.getPermutationMap();
4003 bool changed =
false;
4005 newInBounds.reserve(op.getTransferRank());
4006 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
4008 if (op.isDimInBounds(i)) {
4009 newInBounds.push_back(
true);
4014 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
4015 assert(dimExpr &&
"Broadcast dims must be in-bounds");
4018 newInBounds.push_back(inBounds);
4020 changed |= inBounds;
4030 template <
typename TransferOp>
4032 auto mask = op.getMask();
4036 auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
4040 if (!constantMask.isAllOnesMask())
4043 op.getMaskMutable().clear();
4057 static Value foldRAW(TransferReadOp readOp) {
4058 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4060 auto defWrite = readOp.getSource().
getDefiningOp<vector::TransferWriteOp>();
4063 return defWrite.getVector();
4065 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4066 cast<VectorTransferOpInterface>(readOp.getOperation())))
4068 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4074 if (
Value vec = foldRAW(*
this))
4088 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4092 void TransferReadOp::getEffects(
4095 if (llvm::isa<MemRefType>(getShapedType()))
4123 struct TransferReadAfterWriteToBroadcast
4129 if (readOp.hasOutOfBoundsDim() ||
4130 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4132 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4137 if (readOp.getTransferChunkAccessed() !=
4138 defWrite.getTransferChunkAccessed())
4145 if (readOp.getIndices() != defWrite.getIndices() ||
4146 readOp.getMask() != defWrite.getMask())
4148 Value vec = defWrite.getVector();
4170 broadcastShape[pos.value()] = destShape[pos.index()];
4171 broadcastScalableFlags[pos.value()] =
4172 readOp.getVectorType().getScalableDims()[pos.index()];
4175 broadcastShape, defWrite.getVectorType().getElementType(),
4176 broadcastScalableFlags);
4177 vec = rewriter.
create<vector::BroadcastOp>(loc, broadcastedType, vec);
4188 results.
add<TransferReadAfterWriteToBroadcast>(context);
4198 AffineMapAttr permutationMapAttr,
4200 ArrayAttr inBoundsAttr) {
4201 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
4202 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4203 mask, inBoundsAttr);
4209 AffineMapAttr permutationMapAttr,
4210 ArrayAttr inBoundsAttr) {
4211 build(builder, result, vector, dest, indices, permutationMapAttr,
4212 Value(), inBoundsAttr);
4222 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4225 build(builder, result, vector, dest, indices, permutationMapAttr,
4226 Value(), inBoundsAttr);
4234 auto vectorType = llvm::cast<VectorType>(vector.
getType());
4236 llvm::cast<ShapedType>(dest.
getType()), vectorType);
4237 build(builder, result, vector, dest, indices, permutationMap, inBounds);
4258 if (types.size() != 2)
4259 return parser.
emitError(typesLoc,
"requires two types");
4261 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4263 return parser.
emitError(typesLoc,
"requires vector type");
4264 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4265 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4266 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4267 auto permMapAttrName =
4268 TransferWriteOp::getPermutationMapAttrName(result.
name);
4275 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4282 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4284 maskInfo.
location,
"does not support masks with vector element type");
4287 "expected the same rank for the vector and the "
4288 "results of the permutation map");
4294 result.
addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4296 {1, 1, static_cast<int32_t>(indexInfo.size()),
4297 static_cast<int32_t>(hasMask.succeeded())}));
4298 return failure(llvm::isa<RankedTensorType>(shapedType) &&
  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
4312 ShapedType shapedType = getShapedType();
4314 VectorType maskType = getMaskType();
4315 auto permutationMap = getPermutationMap();
4316 VectorType inferredMaskType =
4320 if (llvm::size(
getIndices()) != shapedType.getRank())
4321 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4325 if (hasBroadcastDim())
4326 return emitOpError(
"should not have broadcast dimensions");
4329 shapedType, vectorType, maskType,
4330 inferredMaskType, permutationMap,
4331 getInBounds() ? *getInBounds() : ArrayAttr())))
4335 [&](Twine t) {
return emitOpError(t); });
4342 Type TransferWriteOp::getExpectedMaskType() {
4363 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4367 if (write.getTransferRank() == 0)
4369 auto rankedTensorType =
4370 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4372 if (!rankedTensorType)
4375 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4379 if (read.getTransferRank() == 0)
4382 if (!read.getPermutationMap().isMinorIdentity() ||
4383 !write.getPermutationMap().isMinorIdentity())
4386 if (read.getTransferRank() != write.getTransferRank())
4389 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4392 if (read.getSource().getType() != rankedTensorType)
4395 if (read.getVectorType() != write.getVectorType())
4398 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4401 auto isNotConstantZero = [](
Value v) {
4403 return !cstOp.has_value() || cstOp.value() != 0;
4405 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4406 llvm::any_of(write.getIndices(), isNotConstantZero))
4409 results.push_back(read.getSource());
4413 static bool checkSameValueWAR(vector::TransferReadOp read,
4414 vector::TransferWriteOp write) {
4415 return read.getSource() == write.getSource() &&
4416 read.getIndices() == write.getIndices() &&
4417 read.getPermutationMap() == write.getPermutationMap() &&
4418 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4437 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4439 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4443 if (!checkSameValueWAR(read, write))
4445 results.push_back(read.getSource());
4451 if (
succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
4462 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4466 void TransferWriteOp::getEffects(
4469 if (llvm::isa<MemRefType>(getShapedType()))
4504 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4506 vector::TransferWriteOp writeToModify = writeOp;
4509 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4513 writeToModify.getSourceMutable().assign(defWrite.getSource());
4518 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4519 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4523 if (!defWrite->hasOneUse())
4525 writeToModify = defWrite;
4526 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4555 struct SwapExtractSliceOfTransferWrite
4562 if (!insertOp.hasUnitStride())
4565 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4566 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4568 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4569 if (!transferOp || !transferOp->hasOneUse())
4574 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4576 "use-def chain is rank-reducing");
4580 if (!extractOp.hasZeroOffset()) {
4582 "ExtractSliceOp has non-zero offset");
4586 if (!llvm::all_of(transferOp.getIndices(), [](
Value value) {
4590 "TranferWriteOp has non-zero offset");
4594 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4596 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
4599 for (
auto [insertSize, extractSize] :
4600 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4603 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
4608 assert(transferOp.getVectorType().hasStaticShape() &&
4609 "expected vector to have a static shape");
4612 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4613 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
4615 insertOp,
"TransferWriteOp may not write the full tensor.");
4621 auto newExtractOp = rewriter.
create<tensor::ExtractSliceOp>(
4622 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4623 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4624 insertOp.getMixedStrides());
4625 auto newTransferWriteOp = rewriter.
create<TransferWriteOp>(
4626 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4627 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4630 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4640 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4648 MemRefType memRefTy) {
4650 return op->
emitOpError(
"most minor memref dim must have unit stride");
4658 if (
failed(verifyLoadStoreMemRefLayout(*
this, memRefTy)))
4662 Type memElemTy = memRefTy.getElementType();
4663 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4664 if (memVecTy != resVecTy)
4665 return emitOpError(
"base memref and result vector types should match");
4666 memElemTy = memVecTy.getElementType();
4669 if (resVecTy.getElementType() != memElemTy)
4670 return emitOpError(
"base and result element types should match");
4671 if (llvm::size(
getIndices()) != memRefTy.getRank())
4672 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
4690 if (
failed(verifyLoadStoreMemRefLayout(*
this, memRefTy)))
4694 Type memElemTy = memRefTy.getElementType();
4695 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4696 if (memVecTy != valueVecTy)
4698 "base memref and valueToStore vector types should match");
4699 memElemTy = memVecTy.getElementType();
4702 if (valueVecTy.getElementType() != memElemTy)
4703 return emitOpError(
"base and valueToStore element type should match");
4704 if (llvm::size(
getIndices()) != memRefTy.getRank())
4705 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
4719 VectorType maskVType = getMaskVectorType();
4720 VectorType passVType = getPassThruVectorType();
4724 if (resVType.getElementType() != memType.getElementType())
4725 return emitOpError(
"base and result element type should match");
4726 if (llvm::size(
getIndices()) != memType.getRank())
4727 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4728 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4729 return emitOpError(
"expected result dim to match mask dim");
4730 if (resVType != passVType)
4731 return emitOpError(
"expected pass_thru of same type as result type");
4744 load, load.getType(), load.getBase(), load.getIndices());
4747 rewriter.
replaceOp(load, load.getPassThru());
4752 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
4759 results.
add<MaskedLoadFolder>(context);
4773 VectorType maskVType = getMaskVectorType();
4777 if (valueVType.getElementType() != memType.getElementType())
4778 return emitOpError(
"base and valueToStore element type should match");
4779 if (llvm::size(
getIndices()) != memType.getRank())
4780 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4781 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4782 return emitOpError(
"expected valueToStore dim to match mask dim");
4795 store, store.getValueToStore(), store.getBase(), store.getIndices());
4803 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
4810 results.
add<MaskedStoreFolder>(context);
4823 VectorType indVType = getIndexVectorType();
4824 VectorType maskVType = getMaskVectorType();
4826 ShapedType baseType = getBaseType();
4828 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4829 return emitOpError(
"requires base to be a memref or ranked tensor type");
4831 if (resVType.getElementType() != baseType.getElementType())
4832 return emitOpError(
"base and result element type should match");
4833 if (llvm::size(
getIndices()) != baseType.getRank())
4834 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
4835 if (resVType.getShape() != indVType.getShape())
4836 return emitOpError(
"expected result dim to match indices dim");
4837 if (resVType.getShape() != maskVType.getShape())
4838 return emitOpError(
"expected result dim to match mask dim");
4839 if (resVType != getPassThruVectorType())
4840 return emitOpError(
"expected pass_thru of same type as result type");
4848 Type GatherOp::getExpectedMaskType() {
4849 auto vecType = this->getIndexVectorType();
4852 vecType.getScalableDims());
4855 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
4869 rewriter.
replaceOp(gather, gather.getPassThru());
4874 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
4881 results.
add<GatherFolder>(context);
4889 VectorType indVType = getIndexVectorType();
4890 VectorType maskVType = getMaskVectorType();
4894 if (valueVType.getElementType() != memType.getElementType())
4895 return emitOpError(
"base and valueToStore element type should match");
4896 if (llvm::size(
getIndices()) != memType.getRank())
4897 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4898 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4899 return emitOpError(
"expected valueToStore dim to match indices dim");
4900 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4901 return emitOpError(
"expected valueToStore dim to match mask dim");
4920 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
4927 results.
add<ScatterFolder>(context);
4935 VectorType maskVType = getMaskVectorType();
4936 VectorType passVType = getPassThruVectorType();
4940 if (resVType.getElementType() != memType.getElementType())
4941 return emitOpError(
"base and result element type should match");
4942 if (llvm::size(
getIndices()) != memType.getRank())
4943 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4944 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4945 return emitOpError(
"expected result dim to match mask dim");
4946 if (resVType != passVType)
4947 return emitOpError(
"expected pass_thru of same type as result type");
4960 expand, expand.getType(), expand.getBase(), expand.getIndices());
4963 rewriter.
replaceOp(expand, expand.getPassThru());
4968 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
4975 results.
add<ExpandLoadFolder>(context);
4983 VectorType maskVType = getMaskVectorType();
4987 if (valueVType.getElementType() != memType.getElementType())
4988 return emitOpError(
"base and valueToStore element type should match");
4989 if (llvm::size(
getIndices()) != memType.getRank())
4990 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4991 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4992 return emitOpError(
"expected valueToStore dim to match mask dim");
4997 class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
5005 compress, compress.getValueToStore(), compress.getBase(),
5006 compress.getIndices());
5014 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
5021 results.
add<CompressStoreFolder>(context);
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  auto isOne = [](int64_t v) { return v == 1; };

  // Special case for n-D to 0-D shape cast: 'b' must be all ones.
  if (rankA == 0 && llvm::all_of(b, isOne))
    return true;

  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  return i == rankA && j == rankB;
}
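
// Illustrative sketch (not part of the dialect sources; the helper name is
// hypothetical): the greedy check above accepts a cast when the smaller-rank
// shape is obtained by multiplying contiguous runs of the larger-rank shape,
// e.g. {4, 6} <-> {2, 2, 6} is valid (4 == 2 * 2) but {4, 6} <-> {2, 12} is
// not. A simplified standalone version (ignoring trailing unit dims):
static bool groupsInto(ArrayRef<int64_t> small, ArrayRef<int64_t> large) {
  size_t i = 0, j = 0;
  while (i < small.size() && j < large.size()) {
    int64_t prod = 1;
    while (prod < small[i] && j < large.size())
      prod *= large[j++];
    if (prod != small[i])
      return false; // No contiguous grouping reproduces this dimension.
    ++i;
  }
  return i == small.size() && j == large.size();
}
// groupsInto({4, 6}, {2, 2, 6}) == true; groupsInto({4, 6}, {2, 12}) == false.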
static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {
  // Check that element type is the same.
  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
    return op->emitOpError("source/result vectors must have same element type");
  auto sourceShape = sourceVectorType.getShape();
  auto resultShape = resultVectorType.getShape();

  // Check that product of source dim sizes matches product of result dim
  // sizes.
  int64_t sourceDimProduct = std::accumulate(
      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
  int64_t resultDimProduct = std::accumulate(
      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
  if (sourceDimProduct != resultDimProduct)
    return op->emitOpError("source/result number of elements must match");

  // Check the rank-expanding and rank-contracting cases.
  unsigned sourceRank = sourceVectorType.getRank();
  unsigned resultRank = resultVectorType.getRank();
  if (sourceRank < resultRank) {
    if (!isValidShapeCast(sourceShape, resultShape))
      return op->emitOpError("invalid shape cast");
  } else if (sourceRank > resultRank) {
    if (!isValidShapeCast(resultShape, sourceShape))
      return op->emitOpError("invalid shape cast");
  }
  return success();
}
5095 auto sourceVectorType =
5096 llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5097 auto resultVectorType =
5098 llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5101 if (sourceVectorType && resultVectorType)
5102 return verifyVectorShapeCast(*
this, sourceVectorType, resultVectorType);
5109 if (getSource().getType() == getResult().getType())
5113 if (
auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5114 if (getResult().getType() == otherOp.getSource().getType())
5115 return otherOp.getSource();
5118 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5119 VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5120 if (srcType.getRank() < resultType.getRank()) {
5121 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5123 }
else if (srcType.getRank() > resultType.getRank()) {
5124 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5130 setOperand(otherOp.getSource());
5135 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5136 if (bcastOp.getSourceType() == getType())
5137 return bcastOp.getSource();
5145 class ShapeCastConstantFolder final :
public OpRewritePattern<ShapeCastOp> {
5152 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5156 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least one dimension left.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
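
// Illustrative sketch (not from the dialect sources; the helper name is
// hypothetical): dropping trailing unit dims turns a shape such as
// {4, 3, 1, 1} into {4, 3}; an all-ones shape such as {1, 1} keeps a single
// trailing dim so the result is never 0-D.
static SmallVector<int64_t> trimTrailingOnes(ArrayRef<int64_t> shape) {
  ArrayRef<int64_t> trimmed = shape;
  while (trimmed.size() > 1 && trimmed.back() == 1)
    trimmed = trimmed.drop_back();
  return SmallVector<int64_t>(trimmed.begin(), trimmed.end());
}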
5206 class ShapeCastCreateMaskFolderTrailingOneDim final
5213 Value shapeOpSrc = shapeOp->getOperand(0);
5214 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
5215 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
5216 if (!createMaskOp && !constantMaskOp)
5219 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5220 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5222 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5223 if (newVecType != shapeOpResTy)
5226 auto numDimsToDrop =
5227 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5234 auto maskOperands = createMaskOp.getOperands();
5235 auto numMaskOperands = maskOperands.size();
5238 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5240 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5241 if (!constant || (constant.value() != 1))
5245 maskOperands.drop_back(numDimsToDrop);
5252 if (constantMaskOp) {
5253 auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
5254 auto numMaskOperands = maskDimSizes.size();
5257 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5259 if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
5263 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5264 ArrayAttr newMaskOperandsAttr = rewriter.
getArrayAttr(newMaskOperands);
5267 newMaskOperandsAttr);
5280 class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
5287 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5292 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5293 broadcastSourceShape = srcType.getShape();
5295 shapeCastOp.getResultVectorType().getShape();
5299 if (broadcastSourceShape ==
5300 shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5302 shapeCastOp, shapeCastOp.getResultVectorType(),
5303 broadcastOp.getSource());
5309 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5310 if (srcType.getNumElements() ==
5311 shapeCastOp.getResultVectorType().getNumElements()) {
5313 shapeCastOp, shapeCastOp.getResultVectorType(),
5314 broadcastOp.getSource());
5327 results.
add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5328 ShapeCastBroadcastFolder>(context);
5336 auto sourceVectorType = getSourceVectorType();
5337 auto resultVectorType = getResultVectorType();
5339 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5340 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5341 return emitOpError(
"dimension size mismatch at: ") << i;
5344 DataLayout dataLayout = DataLayout::closest(*
this);
5345 auto sourceElementBits =
5347 auto resultElementBits =
5350 if (sourceVectorType.getRank() == 0) {
5351 if (sourceElementBits != resultElementBits)
5352 return emitOpError(
"source/result bitwidth of the 0-D vector element "
5353 "types must be equal");
5354 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
5355 resultElementBits * resultVectorType.getShape().back()) {
5357 "source/result bitwidth of the minor 1-D vectors must be equal");
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();
      // Casting a pair of f16 splat values into a single f32 splat.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern into both halves of the 32-bit word.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        // ...
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();
      if (llvm::isa<IntegerType>(dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          // Replicate the narrow splat value across the wider integer.
          APInt intBits = splat.getValue().zext(dstBitWidth);
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          // ...
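
// Illustrative sketch (not from the dialect sources; the helper name is
// hypothetical): replicating a narrow splat into a wider integer, assuming
// srcBits divides dstBits and the value fits in srcBits. For example, an i8
// splat 0xAB bitcast to i32 becomes 0xABABABAB.
static uint64_t replicateSplat(uint64_t value, unsigned srcBits,
                               unsigned dstBits) {
  uint64_t result = value;
  for (unsigned i = 0; i < dstBits / srcBits - 1; ++i)
    result = (result << srcBits) | result;
  return result; // replicateSplat(0xAB, 8, 32) == 0xABABABAB.
}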
5430 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5432 memRefType.getShape().end());
5434 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5443 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
5444 VectorType vectorType =
5448 memRefType.getMemorySpace()));
5453 if (!canonicalType.getLayout().isIdentity())
5454 return emitOpError(
"expects operand to be a memref with identity layout");
5455 if (!getResultMemRefType().getLayout().isIdentity())
5456 return emitOpError(
"expects result to be a memref with identity layout");
5457 if (getResultMemRefType().getMemorySpace() !=
5459 return emitOpError(
"expects result in same memory space");
5462 auto resultType = getResultMemRefType();
5466 "expects result and operand with same underlying scalar type: ")
5468 if (extractShape(sourceType) != extractShape(resultType))
5470 "expects concatenated result and operand shapes to be equal: ")
5481 VectorType vt = llvm::cast<VectorType>(vector.
getType());
5484 for (
unsigned i = 0; i < permutation.size(); ++i) {
5485 transposedShape[i] = vt.getShape()[permutation[i]];
5486 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5491 transposedScalableDims));
5496 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5499 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5501 return attr.reshape(getResultVectorType());
  for (int64_t i = 0, e = perm.size(); i < e; i++) {
    // ...
  }
  // ...

LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}
5541 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5542 return llvm::to_vector<4>(getResultVectorType().
getShape());
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
  // ...
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    for (auto index : permutation2)
      result.push_back(permutation1[index]);
    // ...
    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
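
// Illustrative sketch (not from the dialect sources; the helper name is
// hypothetical): folding two transposes composes their permutations as
// result[i] = perm1[perm2[i]]. Composing perm1 = {1, 2, 0} with its inverse
// perm2 = {2, 0, 1} yields {0, 1, 2}, i.e. the two transposes cancel out.
static SmallVector<int64_t> composePerms(ArrayRef<int64_t> perm1,
                                         ArrayRef<int64_t> perm2) {
  SmallVector<int64_t> result;
  result.reserve(perm2.size());
  for (int64_t index : perm2)
    result.push_back(perm1[index]);
  return result;
}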
5580 struct FoldTransposedScalarBroadcast final
5586 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5590 auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5591 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5593 transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5608 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5613 transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5619 class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
5625 Value transposeSrc = transpOp.getVector();
5626 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
5627 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
5628 if (!createMaskOp && !constantMaskOp)
5636 auto maskOperands = createMaskOp.getOperands();
5641 transpOp, transpOp.getResultVectorType(), newOperands);
5646 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5651 transpOp, transpOp.getResultVectorType(),
5659 void vector::TransposeOp::getCanonicalizationPatterns(
5661 results.
add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5662 TransposeFolder, FoldTransposeSplat>(context);
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  SmallVector<int64_t, 4> maskDimSizes;
  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitError(
          "only supports 'none set' or 'all set' scalable dimensions");
    maskDimSizes.push_back(maskDimSize);
  }

  // Verify that if one mask dim size is zero, they all are zero (the mask is
  // a conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
  }
  for (const auto [resultSize, intAttr] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
5736 build(builder, result, type, operands);
5740 auto vectorType = llvm::cast<VectorType>(getResult().getType());
5742 if (vectorType.getRank() == 0) {
5743 if (getNumOperands() != 1)
5745 "must specify exactly one operand for 0-D create_mask");
5746 }
else if (getNumOperands() !=
5747 llvm::cast<VectorType>(getResult().getType()).getRank()) {
5749 "must specify an operand for each result vector dimension");
5785 VectorType retTy = createMaskOp.getResult().getType();
5786 bool isScalable = retTy.isScalable();
5789 for (
auto [opIdx, operand] :
llvm::enumerate(createMaskOp.getOperands())) {
5794 if (retTy.getScalableDims()[opIdx] && *cst > 0)
5809 auto mulLHS = mul.getRhs();
5810 auto mulRHS = mul.getLhs();
5811 bool isOneOpVscale =
5812 (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
5813 isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
5815 auto isConstantValMatchingDim =
5816 [=, dim = retTy.getShape()[opIdx]](
Value operand) {
5818 return (constantVal.has_value() && constantVal.value() == dim);
5821 bool isOneOpConstantMatchingDim =
5822 isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
5824 if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5830 maskDimSizes.reserve(createMaskOp->getNumOperands());
5831 for (
auto [operand, maxDimSize] : llvm::zip_equal(
5832 createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5837 maskDimSizes.push_back(maxDimSize);
5840 int64_t dimSizeVal =
std::min(dimSize.value(), maxDimSize);
5843 maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
5846 maskDimSizes.push_back(dimSizeVal);
5851 createMaskOp, retTy,
5861 results.
add<CreateMaskFolder>(context);
5872 assert(maskRegionBuilder &&
5873 "builder callback for 'maskRegion' must be present");
5879 maskRegionBuilder(builder, maskableOp);
5886 build(builder, result, resultTypes, mask,
Value(), maskableOp,
5894 build(builder, result, mask, maskableOp, maskRegionBuilder);
5922 MaskOp::ensureTerminator(maskRegion, builder, result.
location);
5936 result.
types.append(resultTypes);
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  // ...
  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
5970 MaskOp>::ensureTerminator(region, builder, loc);
5982 assert(isa<vector::YieldOp>(oldYieldOp) &&
"Expected vector::YieldOp");
5985 if (maskedOp == oldYieldOp)
5988 opBuilder.setInsertionPoint(oldYieldOp);
5989 opBuilder.create<vector::YieldOp>(loc, maskedOp->
getResults());
5991 oldYieldOp->
erase();
5996 Block &block = getMaskRegion().getBlocks().
front();
5998 return emitOpError(
"expects a terminator within the mask region");
6000 return emitOpError(
"expects only one operation to mask");
6003 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
6005 return emitOpError(
"expects a terminator within the mask region");
6007 if (terminator->getNumOperands() != getNumResults())
6009 "expects number of results to match mask region yielded values");
6011 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
6018 return emitOpError(
"expects number of results to match maskable operation "
6019 "number of results");
6021 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
6023 "expects result type to match maskable operation result type");
6026 [](
Type t) { return llvm::isa<VectorType>(t); }) > 1)
6027 return emitOpError(
"multiple vector results not supported");
6030 Type expectedMaskType = maskableOp.getExpectedMaskType();
6031 if (getMask().getType() != expectedMaskType)
6032 return emitOpError(
"expects a ")
6033 << expectedMaskType <<
" mask for the maskable operation";
6036 Value passthru = getPassthru();
6038 if (!maskableOp.supportsPassthru())
6040 "doesn't expect a passthru argument for this maskable operation");
6043 return emitOpError(
"expects result when passthru argument is provided");
6046 return emitOpError(
"expects passthru type to match result type");
6063 Operation *maskableOp = getMaskableOp();
6067 llvm::append_range(results, maskableOp->
getResults());
6079 auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6080 if (maskingOp.getMaskableOp())
6083 if (!maskOp.isEmpty())
6086 Block *block = maskOp.getMaskBlock();
6087 auto terminator = cast<vector::YieldOp>(block->
front());
6088 if (terminator.getNumOperands() == 0)
6091 rewriter.
replaceOp(maskOp, terminator.getOperands());
6099 results.
add<ElideEmptyMaskOp>(context);
6106 Block *block = getMaskBlock();
6110 return &block->
front();
6114 bool MaskOp::hasPassthru() {
return getPassthru() !=
Value(); }
6121 VectorType srcType = getSourceType();
6122 VectorType initialType = getInitialValueType();
6124 int64_t srcRank = srcType.getRank();
6125 int64_t reductionDim = getReductionDim();
6126 if (reductionDim >= srcRank)
6127 return emitOpError(
"reduction dimension ")
6128 << reductionDim <<
" has to be less than " << srcRank;
6131 int64_t initialValueRank = initialType.getRank();
6132 if (initialValueRank != srcRank - 1)
6133 return emitOpError(
"initial value rank ")
6134 << initialValueRank <<
" has to be equal to " << srcRank - 1;
6140 for (
int i = 0; i < srcRank; i++) {
6141 if (i != reductionDim)
6142 expectedShape.push_back(srcShape[i]);
6144 if (!llvm::equal(initialValueShapes, expectedShape)) {
6145 return emitOpError(
"incompatible input/initial value shapes");
6149 Type eltType = getDestType().getElementType();
6151 return emitOpError(
"unsupported reduction type ")
6152 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
6161 .
add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6162 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6163 StridedSliceConstantMaskFolder, TransposeFolder>(
6172 auto constOperand = adaptor.getInput();
6173 if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
  p << "(" << getLaneid() << ")";
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
6198 !getResults().empty());
6228 llvm::SMLoc inputsOperandsLoc;
6240 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
6251 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.
location);
6259 void WarpExecuteOnLane0Op::getSuccessorRegions(
6273 build(builder, result, resultTypes, laneId, warpSize,
6274 std::nullopt, std::nullopt);
6286 assert(args.size() == blockArgTypes.size());
6290 for (
auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
6299 if (expanded == distributed)
6301 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
6302 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
6303 if (!expandedVecType || !distributedVecType)
6304 return op->
emitOpError(
"expected vector type for distributed operands.");
6305 if (expandedVecType.getRank() != distributedVecType.getRank() ||
6306 expandedVecType.getElementType() != distributedVecType.getElementType())
6308 "expected distributed vectors to have same rank and element type.");
  // Every dimension of the expanded type must be a multiple of the matching
  // distributed dimension, and the per-dimension ratios must multiply out to
  // the warp size.
  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiplier of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (std::accumulate(scales.begin(), scales.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
6333 if (getArgs().size() != getWarpRegion().getNumArguments())
6335 "expected same number op arguments and block arguments.");
6337 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
6338 if (yield.getNumOperands() != getNumResults())
6340 "expected same number of yield operands and return values.");
6341 int64_t warpSize = getWarpSize();
6342 for (
auto [regionArg, arg] :
6343 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
6344 if (
failed(verifyDistributedType(regionArg.getType(), arg.getType(),
6345 warpSize, getOperation())))
6348 for (
auto [yieldOperand, result] :
6349 llvm::zip_equal(yield.getOperands(), getResults())) {
6350 if (
failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
6351 warpSize, getOperation())))
6357 bool WarpExecuteOnLane0Op::areTypesCompatible(
Type lhs,
Type rhs) {
6359 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
6364 arith::FastMathFlagsAttr fastmath,
6371 case CombiningKind::ADD:
6374 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6375 result = b.
createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6377 llvm_unreachable(
"invalid value types for ADD reduction");
6379 case CombiningKind::AND:
6383 case CombiningKind::MAXNUMF:
6384 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6385 "expected float values");
6386 result = b.
createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6388 case CombiningKind::MAXIMUMF:
6389 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6390 "expected float values");
6391 result = b.
createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6393 case CombiningKind::MINNUMF:
6394 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6395 "expected float values");
6396 result = b.
createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6398 case CombiningKind::MINIMUMF:
6399 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6400 "expected float values");
6401 result = b.
createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6403 case CombiningKind::MAXSI:
6407 case CombiningKind::MINSI:
6411 case CombiningKind::MAXUI:
6419 case CombiningKind::MUL:
6422 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6423 result = b.
createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6425 llvm_unreachable(
"invalid value types for MUL reduction");
6427 case CombiningKind::OR:
6431 case CombiningKind::XOR:
6437 assert(result &&
"unknown CombiningKind");
6449 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
6469 return builder.
create<MaskOp>(maskableOp->getLoc(),
6470 maskableOp->getResultTypes(), mask, maskableOp,
6487 mask, newValue, passthru);
6494 #define GET_ATTRDEF_CLASSES
6495 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6497 #define GET_OP_CLASSES
6498 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static VectorShape vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width)
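A self-contained sketch of the step-sequence check this helper performs, written over plain integers rather than an ArrayAttr (illustrative only):

#include <cstdint>
#include <vector>

// True iff `values` has exactly `width` entries and reads begin, begin+1, ...,
// i.e. it is a unit-step index sequence starting at `begin`.
static bool isStepIndexArraySketch(const std::vector<int64_t> &values,
                                   int64_t begin, size_t width) {
  if (values.size() != width)
    return false;
  for (size_t i = 0; i < width; ++i)
    if (values[i] != begin + static_cast<int64_t>(i))
      return false;
  return true;
}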
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
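A hedged sketch of how the Builder helpers listed above are typically used to construct attributes; it assumes a live MLIRContext owned by the caller and uses only the accessors shown here.

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Build a handful of attributes with the Builder helpers listed above.
static void buildSomeAttrs(MLIRContext *ctx) {
  Builder b(ctx);
  ArrayAttr offsets = b.getI64ArrayAttr({0, 2, 4});      // [0, 2, 4] as i64 attrs
  DenseI64ArrayAttr shape = b.getDenseI64ArrayAttr({4, 8});
  IntegerAttr one = b.getI64IntegerAttr(1);
  IntegerType i32Ty = b.getIntegerType(32);
  TypedAttr zero = b.getZeroAttr(i32Ty);                  // 0 : i32
  (void)offsets; (void)shape; (void)one; (void)zero;
}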
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute whether the given values/dimensions are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Builder & setElementType(Type newElementType)
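A hedged sketch of using this builder to derive a new vector type from an existing one; it assumes `srcType` is a valid VectorType and that the builder converts implicitly back to a VectorType when the expression is consumed.

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Derive vector<4x8xi1> from an existing vector type by swapping the shape
// and element type while keeping the other properties.
static VectorType deriveMaskTypeSketch(VectorType srcType) {
  return VectorType::Builder(srcType)
      .setShape({4, 8})
      .setElementType(IntegerType::get(srcType.getContext(), 1));
}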
Specialization of arith.constant op that returns an integer of index type.
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
std::optional< Operation::operand_range > getIndices(Operation *op)
Get and set the indices that the given load/store operation is operating on.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
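A hedged sketch of combining two values with makeArithReduction, following the signature listed above; it assumes `b` has a valid insertion point and that `value` and `acc` have a type supported by the chosen CombiningKind.

#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// Combine `value` into `acc` with an addition, i.e. emit the arith op that
// corresponds to CombiningKind::ADD for the operands' element type.
static Value combineWithAdd(OpBuilder &b, Location loc, Value value,
                            Value acc) {
  return vector::makeArithReduction(b, loc, vector::CombiningKind::ADD, value,
                                    acc);
}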
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
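A hedged sketch of wrapping an already-created maskable op with maskOperation per the declaration above; it assumes `reductionOp` is a maskable vector op and that `mask` has the mask type inferred for it.

#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// Wrap a maskable op in a vector.mask region. If `passthru` is provided, the
// masked-off lanes of the result take its values. The produced IR is roughly:
//   %r = vector.mask %mask { vector.reduction <add>, %v : vector<8xf32> into f32 }
//          : vector<8xi1> -> f32
static Operation *maskSketch(OpBuilder &b, Operation *reductionOp, Value mask,
                             Value passthru = Value()) {
  return vector::maskOperation(b, reductionOp, mask, passthru);
}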
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
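A hedged sketch of collecting these patterns and applying them to fixpoint; it assumes the standard greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h is available in the build.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Run the vector-to-vector canonicalization patterns over `op` greedily.
static LogicalResult runVectorCanonicalizations(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  vector::populateVectorToVectorCanonicalizationPatterns(patterns);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}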
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
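A self-contained sketch of the usual row-major suffix-product computation such a stride helper performs (illustrative only):

#include <cstdint>
#include <vector>

// Row-major strides: stride[i] = product of sizes[i+1..N). For sizes {4, 8, 2}
// this returns {16, 2, 1}.
static std::vector<int64_t>
computeStridesSketch(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    running *= sizes[i];
  }
  return strides;
}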
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
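A self-contained sketch of linearizing a multi-dimensional offset against a basis of strides, matching the intent of the declaration above (example values are for illustration only):

#include <cassert>
#include <cstdint>
#include <vector>

// Linear index = sum over dims of offsets[d] * basis[d]. With offsets {1, 3}
// and basis {8, 1} (row-major strides of a 4x8 shape) this yields 11.
static int64_t linearizeSketch(const std::vector<int64_t> &offsets,
                               const std::vector<int64_t> &basis) {
  assert(offsets.size() == basis.size() && "rank mismatch");
  int64_t linear = 0;
  for (size_t d = 0; d < offsets.size(); ++d)
    linear += offsets[d] * basis[d];
  return linear;
}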
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Return a fused vector::ContractionOp which represents a pattern such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
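A hedged, minimal sketch of an OpRewritePattern using the constructor listed above; the op it matches (a type-preserving vector::BroadcastOp) and the rewrite are purely illustrative of the pattern structure, not an in-tree canonicalization.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Illustrative pattern: a broadcast whose source already has the result type
// is a no-op, so forward the source.
struct DropTrivialBroadcast : OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getSource().getType() != op.getType())
      return rewriter.notifyMatchFailure(op, "broadcast changes the type");
    rewriter.replaceOp(op, op.getSource());
    return success();
  }
};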
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.