#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
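// The fragments below are from the mask-format analysis (getMaskFormat in
// upstream VectorOps.cpp), which classifies a mask value as all-true,
// all-false, or unknown by inspecting its defining op: a dense arith.constant
// is scanned bit by bit, a vector.constant_mask is compared against the
// vector shape, and a vector.create_mask is checked for constant operands.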
if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
  int64_t val = 0;
  for (bool b : denseElts.getValues<bool>())
    if (b && val >= 0)
      val++;
    else if (!b && val <= 0)
      val--;

auto shape = m.getType().getShape();
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
  if (maskIdx < dimSize)

auto maskOperands = m.getOperands();
for (Value operand : maskOperands) {
  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
    llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

vector::YieldOp::create(builder, loc);
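// isSupportedCombiningKind: ADD/MUL accept integer, index, and float
// elements; the bitwise and integer min/max kinds require an integer or
// index element; the floating-point min/max kinds require a float element.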
switch (combiningKind) {
case CombiningKind::ADD:
case CombiningKind::MUL:
  return elementType.isIntOrIndexOrFloat();
case CombiningKind::MINUI:
case CombiningKind::MINSI:
case CombiningKind::MAXUI:
case CombiningKind::MAXSI:
case CombiningKind::AND:
case CombiningKind::OR:
case CombiningKind::XOR:
  return elementType.isIntOrIndex();
case CombiningKind::MINNUMF:
case CombiningKind::MAXNUMF:
case CombiningKind::MINIMUMF:
case CombiningKind::MAXIMUMF:
  return llvm::isa<FloatType>(elementType);
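// Transfer-op helpers: the effective vector rank of a transfer excludes the
// rank of a vector element type carried by the source memref/tensor, and the
// default permutation map is a minor identity built from that rank.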
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
                                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(),
      getEffectiveVectorRankForXferOp(shapedType, vectorType),
      shapedType.getContext());
}
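// Store-to-load forwarding check for the masked case: a masked read that is
// padded with the same splat value an (un)masked write stored can still
// forward the written vector directly.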
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for a constant splat as the source of the write.
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
static bool checkSameValueRAW(vector::TransferWriteOp defWrite,
                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}

static bool checkSameValueWAW(vector::TransferWriteOp write,
                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
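// Disjointness test for two transfers on the same base: constant indices that
// differ on a leading (non-transferred) dimension, or that are at least one
// vector length apart on a transferred dimension, prove the accesses do not
// overlap; with testDynamicValueUsingBounds, dynamic indices are additionally
// compared via affine composition and ValueBoundsConstraintSet.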
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
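// Index utilities: advance an n-D slice position in row-major order with
// carry, and convert index lists between Value, OpFoldResult, and int64_t
// representations.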
static void incSlicePosition(MutableArrayRef<int64_t> position,
                             ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return;
    // If the position is past the end of this dimension, reset and carry.
    posInDim = offsetInDim;
  }
}

SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}

SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  return ints;
}

SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) -> Value {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return arith::ConstantIndexOp::create(
                                 builder, loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();

                    return cast<Value>(foldResult);
                  });
  return values;
}

std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
  auto mul = value.getDefiningOp<arith::MulIOp>();
  if (!mul)
    return {};
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  return {};
}

/// Converts the given attribute to an IntegerAttr of the expected type if
/// there is a mismatch.
static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
    if (intAttr.getType() != expectedType)
      return IntegerAttr::get(expectedType, intAttr.getInt());
  }
  return attr;
}
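// Dialect initialization registers the generated attribute and operation
// lists, the inliner interface, and interfaces promised to other dialects
// (bufferization, subset ops, LLVM conversion); constants materialize as
// ub.poison or arith.constant depending on the attribute.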
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}

Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (isa<ub::PoisonAttrInterface>(value))
    return value.getDialect().materializeConstant(builder, value, type, loc);

  return arith::ConstantOp::materialize(builder, value, type, loc);
}
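// Example of the op built and verified below (illustrative):
//   %0 = vector.multi_reduction <add>, %src, %acc [1]
//        : vector<4x8xf32> to vector<4xf32>
// The reduction mask marks dim 1 as reduced; the inferred destination type
// drops exactly the reduced dimensions.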
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}

/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
// Only unit dimensions that are being reduced are folded. If a dimension is
// unit but not reduced, it is kept, so the output type stays the same.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      }
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // All reduced dimensions have size one; a simple extract suffices.
      if (mask)
        mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind,
        vector, acc, fastMathFlags);
}

LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
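// Lowering helper mapping arith.atomicrmw kinds onto vector.reduction
// combining kinds; kinds without a vector counterpart are rejected.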
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::addi:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::ADD, vector);
case arith::AtomicRMWKind::mulf:
case arith::AtomicRMWKind::muli:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::MUL, vector);
case arith::AtomicRMWKind::minimumf:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::MINIMUMF, vector);
case arith::AtomicRMWKind::mins:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::MINSI, vector);
case arith::AtomicRMWKind::minu:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::MINUI, vector);
case arith::AtomicRMWKind::maximumf:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::MAXIMUMF, vector);
case arith::AtomicRMWKind::maxs:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::MAXSI, vector);
case arith::AtomicRMWKind::maxu:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::MAXUI, vector);
case arith::AtomicRMWKind::andi:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::AND, vector);
case arith::AtomicRMWKind::ori:
  return vector::ReductionOp::create(builder, vector.getLoc(),
                                     CombiningKind::OR, vector);
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
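// Example of the custom assembly parsed/printed below (illustrative):
//   %r = vector.contract {indexing_maps = [#map0, #map1, #map2],
//                         iterator_types = ["parallel", "parallel",
//                                           "reduction"]}
//        %lhs, %rhs, %acc
//        : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>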
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;

  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }

  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(result.getContext(),
                                             ContractionOp::getDefaultKind()));

  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  std::array<VectorType, 2> maskTypes = {

void ContractionOp::print(OpAsmPrinter &p) {
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));
      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) !=
            rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer the expected result vector type from the lhs/rhs maps and types.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    AffineMap expectedMap = simplifyAffineMap(extentsMap.compose(resMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
}
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Use the indexing maps to recover the mask dimension sizes (and their
  // scalability) from the LHS and RHS operand types.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhs shape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>>
ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
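// Canonicalization: add(contract(a, b, zero), other) is folded into
// contract(a, b, other) by cloning the contraction with `other` remapped as
// its accumulator; instantiated for both arith.addi and arith.addf.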
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}

void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source) {
  auto vectorTy = cast<VectorType>(source.getType());
  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}
bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}

LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the result type");

  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating
/// the positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}
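// ExtractFromInsertTransposeChainState walks the chain of vector.insert and
// vector.transpose ops feeding a vector.extract: transposes permute the
// tracked position, a matching insert yields the inserted value directly, a
// prefix insert redirects the extract, and disjoint inserts are skipped.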
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the
  /// vector at position `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if positions `a` and `b` agree on every index where both
  /// entries are non-negative (sentinels are negative and ignored).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  }

  LogicalResult handleTransposeOp();
  LogicalResult handleInsertOpWithMatchingPos(Value &res);
  LogicalResult handleInsertOpWithPrefixPos(Value &res);
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
}
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
  return success();
}

LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  res = nextInsertOp.getValueToStore();
  // Case 2.a. early-exit fold.
  return success(canFold());
}

LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  res = nextInsertOp.getValueToStore();

Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  if (extractOp.hasDynamicPosition())
    return Value();

  // If there is nothing to fold, or folding is not possible, bail.
  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
    return Value();

  // Otherwise, fold by updating the op in place.
  extractOp.setStaticPosition(
      ArrayRef(extractPosition).take_front(extractedRank));
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();
}

Value ExtractFromInsertTransposeChainState::fold() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the insert position matches extractPosition exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: the insert position is a prefix of extractPosition.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: the insert is disjoint, keep walking the insert chain.
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }

  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
auto hasZeroDimVectorType = [](Type type) -> bool {
  auto vecType = dyn_cast<VectorType>(type);
  return vecType && vecType.getRank() == 0;
};

// Broadcast-like check: a broadcast/splat, or a shape_cast that only
// prepends unit dimensions.
if (isa<BroadcastOp, SplatOp>(op))
  return true;

auto shapeCast = dyn_cast<ShapeCastOp>(op);
if (!shapeCast)
  return false;

VectorType srcType = shapeCast.getSourceVectorType();
ArrayRef<int64_t> srcShape = srcType.getShape();
uint64_t srcRank = srcType.getRank();
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;

static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp)
    return Value();
  Value input = defOp->getOperand(0);
  if (extractOp.getType() == input.getType())
    return input;

  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;

  if (extractRank > inputRank)
    return Value();

  assert(inputType && "input must be a vector type because of previous checks");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  if (extractType &&
      extractType.getShape() != inputShape.take_back(extractRank))
    return Value();

  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;
  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
  SmallVector<OpFoldResult> newPositions(deltaOverall);
  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
  }
  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
}
/// Fold an extract from a 1-D shuffle by redirecting it to the shuffled
/// operand selected by the mask.
static Value foldExtractFromShuffle(ExtractOp extractOp) {
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // TODO: 0-D or multi-dimensional vectors not supported yet.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Find the shuffled vector to extract from based on the shuffle index.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
}

/// Fold an extract from a shape_cast by linearizing and delinearizing the
/// extraction position across the two shapes.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shape_cast source.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Compute a linearized position from the extract op's vector source.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Delinearize the position using the strides of the shape_cast source.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
}
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  // Avoid non-unit strides for now.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the strided slice.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
}
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    auto extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the extract offset lies in the inserted interval.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the vector inserted.
    if (!disjoint) {
      // If any inner dimension is only partially inserted we have a partial
      // overlap and cannot fold.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getVectorMutable().assign(insertOp.getValueToStore());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // If disjoint, keep looking in the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}
/// Fold the extraction of a scalar from a vector defined by
/// vector.from_elements, e.g.:
///   %0 = vector.from_elements %a, %b : vector<2xf32>
///   %1 = vector.extract %0[0] : f32 from vector<2xf32>
/// folds to %a.
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute the flattened index and fold to the matching operand.
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
}
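// Shared in-place folder for vector.extract and vector.insert: dynamic
// position operands that are known in-range constants are promoted into the
// static position array.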
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
  ArrayRef<int64_t> vectorShape;
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();
  else
    vectorShape = op.getDestVectorType().getShape();

  // If there are no dynamic operands, there is nothing to fold.
  if (!dynamicPosition.size())
    return {};

  // `index` iterates over `dynamicPosition`; `opChange` records whether the
  // op must be updated in place.
  unsigned index = 0;
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (ShapedType::isStatic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
      // Do not fold if the value is out of range.
      if (value >= -1 && value < vectorShape[i]) {
        staticPosition[i] = attr.getInt();
        opChange = true;
        continue;
      }
    }
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    return op.getResult();
  }
  return {};
}
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
                                                ArrayRef<int64_t> staticPos,
                                                int64_t poisonVal) {
  if (!is_contained(staticPos, poisonVal))
    return {};
  return ub::PoisonAttr::get(context);
}

// Fold an extract from a poison source into poison.
if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
  return srcAttr;

// Fold an extract from a DenseElementsAttr source into an attribute.
auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
if (!denseAttr)
  return {};

if (denseAttr.isSplat()) {
  Attribute newAttr = denseAttr.getSplatValue<Attribute>();
  if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
    newAttr = DenseElementsAttr::get(vecDstType, newAttr);
  return newAttr;
}

auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
if (vecTy.isScalable())
  return {};

if (extractOp.hasDynamicPosition()) {
  return {};
}

// Compute the linearized position of the first extracted element.
SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
copy(extractOp.getStaticPosition(), completePositions.begin());
int64_t startPos =
    linearize(completePositions, computeStrides(vecTy.getShape()));
auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

Attribute newAttr;
if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
  SmallVector<Attribute> elementValues(
      denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
} else {
  newAttr = *denseValuesBegin;
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
    return getVector();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;

  return inplaceFolded;
}
// From ExtractOpFromBroadcast: rewrite extract(broadcast-like(src)) as a
// direct broadcast of src when the result type is broadcastable from it.
Operation *defOp = extractOp.getVector().getDefiningOp();
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!defOp || !outType)
  return failure();
Value source = defOp->getOperand(0);
if (isBroadcastableTo(source.getType(), outType) !=
    BroadcastableToResult::Success)
  return failure();

// From ExtractOpFromCreateMask: rewrite extract(create_mask) as a smaller
// create_mask (or an all-false constant) when the outcome can be decided.
auto createMaskOp =
    extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
  return failure();

VectorType extractedMaskType =
    llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
if (!extractedMaskType)
  return failure();

auto maskOperands = createMaskOp.getOperands();
ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
VectorType maskType = createMaskOp.getVectorType();

bool containsUnknownDims = false;
bool allFalse = false;

for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
     ++dimIdx) {
  int64_t pos = extractOpPos[dimIdx];
  Value operand = maskOperands[dimIdx];
  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
  if (!constantOp) {
    // Bounds of this dim unknown.
    containsUnknownDims = true;
    continue;
  }

  int64_t createMaskBound =
      llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

  if (pos != ShapedType::kDynamic) {
    // If any position is outside the bound the extracted mask is all-false.
    allFalse |= pos >= createMaskBound;
  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
    // This dim is not all-true; with a dynamic index we cannot tell whether
    // the extraction lands in the true or the false region of the mask.
    containsUnknownDims = true;
  }
}

if (allFalse) {
} else if (!containsUnknownDims) {
  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
      extractOp, extractedMaskType,
      maskOperands.drop_front(extractOpPos.size()));
}
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();

  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();

  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
  return success();
}

/// Try to canonicalize the extraction of a subvector from a vector defined by
/// vector.from_elements.
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of the first extracted element and flatten it.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);
  int flatIndex = 0;
  int stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// Returns true if all `operands` are defined by `defOp`.
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
  if (operands.empty())
    return false;

  return llvm::all_of(operands, [&](Value operand) {
    Operation *currentDef = operand.getDefiningOp();
    return currentDef == defOp;
  });
}
static LogicalResult
foldToElementsFromElements(ToElementsOp toElementsOp,
                           SmallVectorImpl<OpFoldResult> &results) {
  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();

  llvm::append_range(results, fromElementsOp.getElements());
  return success();
}

LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
  return foldToElementsFromElements(*this, results);
}

LogicalResult
ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                               ToElementsOp::Adaptor adaptor,
                               SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vecType = cast<VectorType>(adaptor.getSource().getType());
  Type elType = vecType.getElementType();
  inferredReturnTypes.append(vecType.getNumElements(), elType);
  return success();
}
/// Folds vector.from_elements(vector.to_elements(%v)) into %v.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
  OperandRange fromElemsOperands = fromElementsOp.getElements();
  if (fromElemsOperands.empty())
    return {};

  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
  if (!toElementsOp)
    return {};

  Value toElementsInput = toElementsOp.getSource();
  if (fromElementsOp.getType() == toElementsInput.getType() &&
      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
    return toElementsInput;
  }

  return {};
}

/// Fold vector.from_elements to a constant when all operands are constants.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
                                               ArrayRef<Attribute> elements) {
  if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
    return {};

  auto destVecType = fromElementsOp.getDest().getType();
  auto destEltType = destVecType.getElementType();
  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
    return {};

  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
    return convertIntegerAttr(attr, destEltType);
  });
  return DenseElementsAttr::get(destVecType, convertedElements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (auto res = foldFromElementsToElements(*this))
    return res;
  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
    return res;
  return {};
}

/// Rewrite vector.from_elements with all elements equal as a broadcast.
static LogicalResult
rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
                               PatternRewriter &rewriter) {
  if (!llvm::all_equal(fromElementsOp.getElements()))
    return failure();
  rewriter.replaceOpWithNewOp<BroadcastOp>(
      fromElementsOp, fromElementsOp.getType(),
      fromElementsOp.getElements().front());
  return success();
}
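// Example of the rewrite implemented below (illustrative): when every element
// of a from_elements op is an in-order scalar extract of one source,
//   %0 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
//   %1 = vector.extract %src[0, 1] : f32 from vector<2x2xf32>
//   %2 = vector.from_elements %0, %1 : vector<2xf32>
// collapses to a (possibly extracted) shape_cast of %src.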
class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {
    // Handled by rewriteFromElementsAsBroadcast.
    if (fromElements.getType().getNumElements() == 1)
      return failure();

    // The common source of the extracts; must be the same for all elements.
    TypedValue<VectorType> source;
    SmallVector<int64_t> expectedPosition;
    ArrayRef<int64_t> combinedPosition;

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {
      // Check that the element is from a vector.extract operation.
      auto extractOp = element.getDefiningOp<vector::ExtractOp>();
      if (!extractOp) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");
      }

      // Check that all elements share one source.
      if (insertIndex == 0) {
        source = extractOp.getVector();
      } else if (extractOp.getVector() != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }

      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      // The elements must form a contiguous suffix of the source.
      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        int64_t numSuffixElms = 1;
        int64_t index = rank;
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);
          --index;
        }
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements do not form a suffix of source");
        }
        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);
      }

      // The elements must be extracted in ascending order.
      else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      }
      increment(expectedPosition, source.getType().getShape());
    }

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        fromElements, fromElements.getType(), extracted);
    return success();
  }

  /// Increments n-D `indices` by 1, starting from the innermost dimension.
  static void increment(MutableArrayRef<int64_t> indices,
                        ArrayRef<int64_t> shape) {
    for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
      indices[dim] += 1;
      if (indices[dim] < shape[dim])
        break;
      indices[dim] = 0;
    }
  }
};
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
int64_t rankDiff = dstShape.size() - srcShape.size();
int64_t dstDim = rankDiff;
llvm::SetVector<int64_t> res;
for (auto [s1, s2] :
     llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
  if (s1 != s2) {
    assert(s1 == 1 && "expected \"dim-1\" broadcasting");
    res.insert(dstDim);
  }
  ++dstDim;
}

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  // Scalar broadcast is without any unit dim broadcast.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness check: all non-broadcast dimensions must match the
  // source shape.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // If scalar -> dstShape broadcast, just do it.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Compute the broadcast shape (broadcast dims at the front) and the
  // permutation that moves each dim back into place.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // A broadcast dim is brought to the head of the broadcast shape.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // A non-broadcast dim comes from the src shape and is permuted into
      // place afterwards.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no "dim-1" broadcasts.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // If any dimension needs to be permuted, emit a vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  return res;
}
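// Broadcast legality: a scalar source broadcasts to any vector of matching
// element type; a vector source must not exceed the destination rank, and
// every trailing source dimension must equal the destination dimension or be
// a unit dimension (with matching scalability rules).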
BroadcastableToResult mlir::vector::isBroadcastableTo(
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcast a scalar to a vector of the same element type.
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // The source must have an exact match or a singleton value for all
  // trailing dimensions (leading dimensions are simply duplicated).
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine, everything else should be rejected when mixing
        // fixed-width and scalable dims.
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}
LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
if (!srcShapeCast)
  return failure();

VectorType srcType = srcShapeCast.getSourceVectorType();
VectorType destType = broadcastOp.getResultVectorType();
if (vector::isBroadcastableTo(srcType, destType) !=
    BroadcastableToResult::Success)
  return failure();

ArrayRef<int64_t> srcShape = srcType.getShape();
ArrayRef<int64_t> shapecastShape =
    srcShapeCast.getResultVectorType().getShape();
// The shape_cast may only add or remove leading unit dimensions.
unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
if (!llvm::equal(srcShape.take_back(numTrailingDims),
                 shapecastShape.take_back(numTrailingDims)))
  return failure();

assert(all_of(srcShape.drop_back(numTrailingDims),
              [](int64_t E) { return E == 1; }) &&
       all_of(shapecastShape.drop_back(numTrailingDims),
              [](int64_t E) { return E == 1; }) &&
       "ill-formed shape_cast");

broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());
  return {};
}

// Fold broadcast(broadcast(x)) into broadcast(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
                                             broadcastOp.getResultVectorType(),
                                             srcBroadcast.getSource());
    return success();
  }
};

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder>(context);
}
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // Construct the resulting type: the leading dimension matches the mask
  // length, all trailing dimensions match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}

template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin,
                             size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // The 0-D case is handled by a canonicalization pattern.
  if (v1Type.getRank() == 0)
    return {};

  // Fold identity shuffles that select exactly one operand.
  auto mask = getMask();
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return getV1();
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return getV2();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // Poison inputs need special handling since they are not DenseElementsAttr;
  // a poison index selects an arbitrary existing element.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    if (!v2DenseAttr)
      return {};
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    if (!v1DenseAttr)
      return {};
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
  }
// From Canonicalize0DShuffleOp: a 0-D shuffle with a single-entry mask is
// just a selection of one operand.
VectorType v1VectorType = shuffleOp.getV1VectorType();
ArrayRef<int64_t> mask = shuffleOp.getMask();
if (v1VectorType.getRank() > 0)
  return failure();
if (mask.size() != 1)
  return failure();

/// Return the scalar operand if `value` is defined by a broadcast or splat
/// of a scalar, and an empty Value otherwise.
static Value getScalarSplatSource(Value value) {
  Operation *defOp = value.getDefiningOp();
  if (!defOp)
    return {};

  if (auto splat = dyn_cast<vector::SplatOp>(defOp))
    return splat.getInput();

  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
  if (!broadcast)
    return {};

  // A broadcast of a vector is not a scalar splat.
  if (isa<VectorType>(broadcast.getSourceType()))
    return {};

  return broadcast.getSource();
}

// From ShuffleSplat: a shuffle of two splats of the same scalar is a splat.
Value splat = getScalarSplatSource(op.getV1());
if (!splat || getScalarSplatSource(op.getV2()) != splat)
  return failure();

// From ShuffleInterleave: recognize an interleaving mask.
VectorType resultType = op.getResultVectorType();
if (resultType.isScalable())
  return rewriter.notifyMatchFailure(
      op, "ShuffleOp can't represent a scalable interleave");

if (resultType.getRank() != 1)
  return rewriter.notifyMatchFailure(
      op, "ShuffleOp can't represent an n-D interleave");

VectorType sourceType = op.getV1VectorType();
if (sourceType != op.getV2VectorType() ||
    sourceType.getNumElements() * 2 != resultType.getNumElements()) {
  return rewriter.notifyMatchFailure(
      op, "ShuffleOp types don't match an interleave");
}

ArrayRef<int64_t> shuffleMask = op.getMask();
int64_t resultVectorSize = resultType.getNumElements();
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
  int64_t maskValueA = shuffleMask[i * 2];
  int64_t maskValueB = shuffleMask[(i * 2) + 1];
  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
    return rewriter.notifyMatchFailure(op, "ShuffleOp mask not interleaving");
}

void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);
}
void InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest) {
  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<int64_t> position) {
  SmallVector<OpFoldResult> posVals;
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

LogicalResult InsertOp::verify() {
  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
    if (srcTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the source operand");

  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding "
                  "dest vector dimension";
      }
    }
  }
  return success();
}

// Calculate the linearized position of the continuous chunk of elements to
// insert, based on the positions to insert at.
static int64_t calculateInsertPosition(VectorType destTy,
                                       ArrayRef<int64_t> positions) {
  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
  assert(positions.size() <= completePositions.size() &&
         "positions size must be less than or equal to destTy rank");
  copy(positions, completePositions.begin());
  return linearize(completePositions, computeStrides(destTy.getShape()));
}

// If insertOp is only inserting unit dimensions it can be transformed into a
// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType =
        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
    return success();
  }
};

// From InsertSplatToSplat: an insert of a splat into a splat of the same
// scalar is that splat.
Value splat = getScalarSplatSource(op.getValueToStore());
if (!splat || getScalarSplatSource(op.getDest()) != splat)
  return failure();
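// InsertChainFullyInitialized: when a chain of static vector.insert ops
// initializes every element of the destination, the whole chain collapses
// into one vector.from_elements; later inserts in the chain win over earlier,
// overwritten ones.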
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    VectorType destTy = op.getDestVectorType();
    if (destTy.isScalable())
      return failure();
    // Only rewrite the tail of an insert chain.
    for (Operation *user : op.getResult().getUsers())
      if (auto insertOp = dyn_cast<InsertOp>(user))
        if (insertOp.getDest() == op.getResult())
          return failure();

    InsertOp currentOp = op;
    SmallVector<InsertOp> chainInsertOps;
    while (currentOp) {
      // Dynamic positions are not supported.
      if (currentOp.hasDynamicPosition())
        return failure();

      chainInsertOps.push_back(currentOp);
      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
      // Require single use on intermediate inserts to avoid duplicating work.
      if (currentOp && !currentOp->hasOneUse())
        return failure();
    }

    int64_t vectorSize = destTy.getNumElements();
    int64_t initializedCount = 0;
    SmallVector<bool> initializedDestIdxs(vectorSize, false);
    SmallVector<int64_t> pendingInsertPos;
    SmallVector<int64_t> pendingInsertSize;
    SmallVector<Value> pendingInsertValues;

    for (auto insertOp : chainInsertOps) {
      // The insert op folder handles inserts at poison indices.
      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
        return failure();

      // Calculate the linearized position for inserting elements.
      int64_t insertBeginPosition =
          calculateInsertPosition(destTy, insertOp.getStaticPosition());

      // The insert positions covered are defined by the inserted value size.
      int64_t insertSize = 1;
      if (auto srcVectorType =
              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
        insertSize = srcVectorType.getNumElements();

      assert(insertBeginPosition + insertSize <= vectorSize &&
             "insert would overflow the vector");

      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
                                           insertBeginPosition + insertSize)) {
        if (initializedDestIdxs[index])
          continue;
        initializedDestIdxs[index] = true;
        ++initializedCount;
      }

      // Defer op creation until the pattern is known to succeed.
      pendingInsertPos.push_back(insertBeginPosition);
      pendingInsertSize.push_back(insertSize);
      pendingInsertValues.push_back(insertOp.getValueToStore());

      if (initializedCount == vectorSize)
        break;
    }

    // Some positions are not initialized.
    if (initializedCount != vectorSize)
      return failure();

    SmallVector<Value> elements(vectorSize);
    for (auto [insertBeginPosition, insertSize, valueToStore] :
         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
                                 pendingInsertValues))) {
      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());

      if (!srcVectorType) {
        elements[insertBeginPosition] = valueToStore;
        continue;
      }

      SmallVector<Type> elementToInsertTypes(insertSize,
                                             srcVectorType.getElementType());
      // Get all elements of the inserted vector in row-major order.
      auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
          op.getLoc(), elementToInsertTypes, valueToStore);
      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
        elements[insertBeginPosition + linearIdx] =
            elementsToInsert.getResult(linearIdx);
      }
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
    return success();
  }
};
static Attribute
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
                                  Attribute dstAttr,
                                  int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst)
    return {};

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  // Calculate the linearized position of the elements to insert.
  int64_t insertBeginPosition =
      calculateInsertPosition(destTy, insertOp.getStaticPosition());
  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  // Convert integer attributes on an element-type mismatch.
  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(convertIntegerAttr(value, destEltType));
  } else {
    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
  }

  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);

  return DenseElementsAttr::get(destTy, allValues);
}

/// Replace the `dest` operand of an insert op with the root dest of the
/// insert op use chain when the positions match.
static Value foldInsertUseChain(InsertOp insertOp) {
  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
  if (!destInsert)
    return {};

  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
    return {};

  insertOp.setOperand(1, destInsert.getDest());
  return insertOp.getResult();
}
3532 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3533 InsertChainFullyInitialized>(context);
3539 constexpr int64_t vectorSizeFoldThreshold = 256;
3543 if (getNumIndices() == 0 && getValueToStoreType() ==
getType())
3544 return getValueToStore();
3554 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3557 *
this, adaptor.getValueToStore(), adaptor.getDest(),
3558 vectorSizeFoldThreshold)) {
3562 return inplaceFolded;
3584 template <
typename OpType>
3586 ArrayAttr arrayAttr,
3588 StringRef attrName) {
3589 if (arrayAttr.size() > shape.size())
3590 return op.emitOpError(
"expected ")
3591 << attrName <<
" attribute of rank no greater than vector rank";
3598 template <
typename OpType>
3599 static LogicalResult
3601 int64_t
max, StringRef attrName,
3602 bool halfOpen =
true) {
3603 for (
auto attr : arrayAttr) {
3604 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3608 if (val < min || val >= upper)
3609 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
3610 <<
min <<
", " << upper <<
")";
3618 template <
typename OpType>
3619 static LogicalResult
3622 bool halfOpen =
true, int64_t
min = 0) {
3623 for (
auto [index, attrDimPair] :
3625 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3626 int64_t
max = std::get<1>(attrDimPair);
3629 if (val < min || val >=
max)
3630 return op.emitOpError(
"expected ")
3631 << attrName <<
" dimension " << index <<
" to be confined to ["
3632 <<
min <<
", " <<
max <<
")";
3642 template <
typename OpType>
3644 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3646 bool halfOpen =
true, int64_t
min = 1) {
3647 assert(arrayAttr1.size() <= shape.size());
3648 assert(arrayAttr2.size() <= shape.size());
3649 for (
auto [index, it] :
3651 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3652 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3653 int64_t
max = std::get<2>(it);
3656 if (val1 + val2 < 0 || val1 + val2 >=
max)
3657 return op.emitOpError(
"expected sum(")
3658 << attrName1 <<
", " << attrName2 <<
") dimension " << index
3659 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3666 auto attrs = llvm::map_range(values, [context](int64_t v) ->
Attribute {
3673 auto sourceVectorType = getSourceVectorType();
3674 auto destVectorType = getDestVectorType();
3675 auto offsets = getOffsetsAttr();
3676 auto strides = getStridesAttr();
3677 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3679 "expected offsets of same size as destination vector rank");
3680 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3681 return emitOpError(
"expected strides of same size as source vector rank");
3682 if (sourceVectorType.getRank() > destVectorType.getRank())
3684 "expected source rank to be no greater than destination rank");
3686 auto sourceShape = sourceVectorType.getShape();
3687 auto destShape = destVectorType.getShape();
3689 destShape.size() - sourceShape.size(), 0);
3690 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3691 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3692 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3701 offName,
"source vector shape",
3705 unsigned rankDiff = destShape.size() - sourceShape.size();
3706 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3707 if (sourceVectorType.getScalableDims()[idx] !=
3708 destVectorType.getScalableDims()[idx + rankDiff]) {
3709 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3712 if (sourceVectorType.getScalableDims()[idx]) {
3713 auto sourceSize = sourceShape[idx];
3714 auto destSize = destShape[idx + rankDiff];
3715 if (sourceSize != destSize) {
3716 return emitOpError(
"expected size at idx=")
3718 << (
" to match the corresponding base size from the input "
3720 << sourceSize << (
" vs ") << destSize << (
")");
3730 class FoldInsertStridedSliceSplat final
3735 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3738 auto dst = insertStridedSliceOp.getDest();
3739 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3740 if (!splat || getScalarSplatSource(dst) != splat)
3743 rewriter.
replaceOp(insertStridedSliceOp, dst);
3750 class FoldInsertStridedSliceOfExtract final
3755 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3757 auto extractStridedSliceOp =
3758 insertStridedSliceOp.getValueToStore()
3759 .getDefiningOp<vector::ExtractStridedSliceOp>();
3761 if (!extractStridedSliceOp)
3764 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3768 if (extractStridedSliceOp.getStrides() !=
3769 insertStridedSliceOp.getStrides() ||
3770 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3773 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3780 class InsertStridedSliceConstantFolder final
3787 static constexpr int64_t vectorSizeFoldThreshold = 256;
3798 VectorType destTy = destVector.getType();
3799 if (destTy.isScalable())
3803 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3804 !destVector.hasOneUse())
3813 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3817 if (op.hasNonUnitStrides())
3820 VectorType sliceVecTy = sourceValue.getType();
3822 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3832 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3833 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3834 auto sliceValuesIt = denseSlice.value_begin<
Attribute>();
3835 auto newValues = llvm::to_vector(denseDest.getValues<
Attribute>());
3838 currDestPosition.begin() + rankDifference, currDestPosition.end());
3842 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
3843 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
3844 assert(sliceValuesIt != denseSlice.value_end<
Attribute>() &&
3845 "Invalid slice element");
3846 newValues[linearizedPosition] = *sliceValuesIt;
3859 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3861 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3862 InsertStridedSliceConstantFolder>(context);
3865 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3866 if (getSourceVectorType() == getDestVectorType())
3867 return getValueToStore();
3883 p <<
" " << getLhs() <<
", " << getRhs();
3885 p <<
", " << getAcc();
3888 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
3899 if (operandsInfo.size() < 2)
3901 "expected at least 2 operands");
3902 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3903 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3906 "expected vector type for operand #1");
3911 vRHS.getScalableDims()[0]};
3913 vLHS.getElementType(), scalableDimsRes);
3917 resType =
VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3923 OuterProductOp::getKindAttrName(result.
name),
3925 OuterProductOp::getDefaultKind()));
3931 (operandsInfo.size() > 2 &&
3937 Type tRHS = getOperandTypeRHS();
3938 VectorType vLHS = getOperandVectorTypeLHS(),
3939 vRHS = llvm::dyn_cast<VectorType>(tRHS),
3940 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3942 if (vLHS.getRank() != 1)
3943 return emitOpError(
"expected 1-d vector for operand #1");
3947 if (vRHS.getRank() != 1)
3948 return emitOpError(
"expected 1-d vector for operand #2");
3949 if (vRES.getRank() != 2)
3950 return emitOpError(
"expected 2-d vector result");
3951 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3952 return emitOpError(
"expected #1 operand dim to match result dim #1");
3953 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3954 return emitOpError(
"expected #2 operand dim to match result dim #2");
3955 if (vLHS.isScalable() && !vRHS.isScalable()) {
3959 "expected either both or only #2 operand dim to be scalable");
3963 if (vRES.getRank() != 1)
3964 return emitOpError(
"expected 1-d vector result");
3965 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3966 return emitOpError(
"expected #1 operand dim to match result dim #1");
3969 if (vACC && vACC != vRES)
3970 return emitOpError(
"expected operand #3 of same type as result type");
3974 return emitOpError(
"unsupported outerproduct type");
3983 Type OuterProductOp::getExpectedMaskType() {
3984 auto vecType = this->getResultVectorType();
3987 vecType.getScalableDims());
3999 ArrayAttr offsets, ArrayAttr sizes,
4000 ArrayAttr strides) {
4001 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4003 shape.reserve(vectorType.getRank());
4005 for (
unsigned e = offsets.size(); idx < e; ++idx)
4006 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4007 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4008 shape.push_back(vectorType.getShape()[idx]);
4011 vectorType.getScalableDims());
4024 offsetsAttr, sizesAttr, stridesAttr));
4025 result.
addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.
name),
4029 result.
addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.
name),
4034 auto type = getSourceVectorType();
4035 auto offsets = getOffsetsAttr();
4036 auto sizes = getSizesAttr();
4037 auto strides = getStridesAttr();
4038 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4040 "expected offsets, sizes and strides attributes of same size");
4042 auto shape = type.getShape();
4043 auto offName = getOffsetsAttrName();
4044 auto sizesName = getSizesAttrName();
4045 auto stridesName = getStridesAttrName();
4061 shape, offName, sizesName,
4066 offsets, sizes, strides);
4067 if (getResult().
getType() != resultType)
4068 return emitOpError(
"expected result type to be ") << resultType;
4070 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
4071 if (type.getScalableDims()[idx]) {
4072 auto inputDim = type.getShape()[idx];
4073 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4074 if (inputDim != inputSize)
4075 return emitOpError(
"expected size at idx=")
4077 << (
" to match the corresponding base size from the input "
4079 << inputSize << (
" vs ") << inputDim << (
")");
4089 static LogicalResult
4092 auto getElement = [](ArrayAttr array,
int idx) {
4093 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4095 ArrayAttr extractOffsets = op.getOffsets();
4097 ArrayAttr extractSizes = op.getSizes();
4098 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
4100 if (op.getSourceVectorType().getRank() !=
4101 insertOp.getSourceVectorType().getRank())
4103 ArrayAttr insertOffsets = insertOp.getOffsets();
4104 ArrayAttr insertStrides = insertOp.getStrides();
4107 if (extractOffsets.size() > insertOffsets.size())
4109 bool patialoverlap =
false;
4110 bool disjoint =
false;
4112 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4113 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
4115 int64_t start = getElement(insertOffsets, dim);
4116 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4117 int64_t offset = getElement(extractOffsets, dim);
4118 int64_t size = getElement(extractSizes, dim);
4120 if (start <= offset && offset < end) {
4123 if (offset + size > end)
4124 patialoverlap =
true;
4125 offsetDiffs.push_back(offset - start);
4132 if (!disjoint && !patialoverlap) {
4133 op.setOperand(insertOp.getValueToStore());
4142 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4157 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4162 if (op.hasNonUnitStrides())
4165 VectorType sourceVecTy = op.getSourceVectorType();
4169 VectorType sliceVecTy = op.getType();
4171 int64_t rank = sliceVecTy.getRank();
4183 const auto denseValuesBegin = dense.value_begin<
Attribute>();
4185 sliceValues.reserve(sliceVecTy.getNumElements());
4188 int64_t linearizedPosition =
linearize(currSlicePosition, sourceStrides);
4189 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4191 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4192 }
while (succeeded(
incSlicePosition(currSlicePosition, sliceShape, offsets)));
4194 assert(
static_cast<int64_t
>(sliceValues.size()) ==
4195 sliceVecTy.getNumElements() &&
4196 "Invalid number of slice elements");
4200 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4201 if (getSourceVectorType() == getResult().
getType())
4208 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
4233 class StridedSliceCreateMaskFolder final
4238 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4240 Location loc = extractStridedSliceOp.getLoc();
4244 extractStridedSliceOp.getVector().getDefiningOp<CreateMaskOp>();
4248 if (extractStridedSliceOp.hasNonUnitStrides())
4261 sliceMaskDimSizes.reserve(maskDimSizes.size());
4265 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4266 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4270 IntegerAttr offsetAttr =
4272 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4273 Value sliceMaskDimSize =
4274 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4275 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4280 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4284 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4292 class StridedSliceConstantMaskFolder final
4297 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4301 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
4302 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4303 if (!constantMaskOp)
4306 if (extractStridedSliceOp.hasNonUnitStrides())
4319 sliceMaskDimSizes.reserve(maskDimSizes.size());
4320 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4321 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4322 int64_t sliceMaskDimSize =
std::max(
4323 static_cast<int64_t
>(0),
4324 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4325 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4328 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4329 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4330 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4333 if (llvm::is_contained(sliceMaskDimSizes, 0))
4334 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4339 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4347 class StridedSliceBroadcast final
4359 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4360 auto dstVecType = llvm::cast<VectorType>(op.getType());
4361 unsigned dstRank = dstVecType.getRank();
4362 unsigned rankDiff = dstRank - srcRank;
4366 bool needsSlice =
false;
4367 for (
unsigned i = 0; i < srcRank; i++) {
4368 if (srcVecType.getDimSize(i) != 1 &&
4369 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4380 for (
unsigned i = 0; i < srcRank; i++) {
4381 if (srcVecType.getDimSize(i) == 1) {
4389 source = ExtractStridedSliceOp::create(
4390 rewriter, op->getLoc(), source, offsets, sizes,
4399 class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
4406 Value splat = getScalarSplatSource(op.getVector());
4430 class ContiguousExtractStridedSliceToExtract final
4437 if (op.hasNonUnitStrides())
4439 Value source = op.getOperand();
4440 auto sourceType = cast<VectorType>(source.
getType());
4441 if (sourceType.isScalable() || sourceType.getRank() == 0)
4450 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4451 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4458 if (numOffsets == 0)
4463 if (numOffsets == sourceType.getRank() &&
4464 static_cast<int>(sizes.size()) == sourceType.getRank())
4468 for (
int i = 0; i < numOffsets; ++i) {
4476 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
4477 sizes[numOffsets] == 1) {
4482 auto extractOffsets =
ArrayRef(offsets).take_front(numOffsets);
4483 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4492 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4496 results.
add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
4497 StridedSliceBroadcast, StridedSliceSplat,
4498 ContiguousExtractStridedSliceToExtract>(context);
4507 VectorType vectorType,
Value source,
4508 ValueRange indices, std::optional<Value> padding,
4509 AffineMapAttr permutationMapAttr,
4510 ArrayAttr inBoundsAttr) {
4512 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4514 padding = ub::PoisonOp::create(builder, result.
location, elemType);
4515 build(builder, result, vectorType, source, indices, permutationMapAttr,
4516 *padding,
Value(), inBoundsAttr);
4521 VectorType vectorType,
Value source,
4522 ValueRange indices, std::optional<Value> padding,
4526 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4530 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4532 padding = ub::PoisonOp::create(builder, result.
location, elemType);
4533 build(builder, result, vectorType, source, indices, *padding,
4534 permutationMapAttr, inBoundsAttr);
4539 VectorType vectorType,
Value source,
4540 ValueRange indices, std::optional<Value> padding,
4543 llvm::cast<ShapedType>(source.
getType()), vectorType);
4545 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4549 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4551 padding = ub::PoisonOp::create(builder, result.
location, elemType);
4552 build(builder, result, vectorType, source, indices, permutationMapAttr,
4554 Value(), inBoundsAttr);
4557 template <
typename EmitFun>
4559 EmitFun emitOpError) {
4561 for (
auto expr : permutationMap.
getResults()) {
4562 auto dim = dyn_cast<AffineDimExpr>(expr);
4563 auto zero = dyn_cast<AffineConstantExpr>(expr);
4565 if (zero.getValue() != 0) {
4567 "requires a projected permutation_map (at most one dim or the zero "
4568 "constant can appear in each result)");
4573 return emitOpError(
"requires a projected permutation_map (at most one "
4574 "dim or the zero constant can appear in each result)");
4576 if (seen[dim.getPosition()]) {
4578 "requires a permutation_map that is a permutation (found one dim "
4579 "used more than once)");
4581 seen[dim.getPosition()] =
true;
4586 static LogicalResult
4588 VectorType vectorType, VectorType maskType,
4589 VectorType inferredMaskType,
AffineMap permutationMap,
4590 ArrayAttr inBounds) {
4591 if (op->hasAttr(
"masked")) {
4592 return op->emitOpError(
"masked attribute has been removed. "
4593 "Use in_bounds instead.");
4596 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4597 return op->emitOpError(
4598 "requires source to be a memref or ranked tensor type");
4600 auto elementType = shapedType.getElementType();
4601 DataLayout dataLayout = DataLayout::closest(op);
4602 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4604 unsigned sourceVecSize =
4606 vectorElementType.getShape().back();
4607 unsigned resultVecSize =
4609 vectorType.getShape().back();
4610 if (resultVecSize % sourceVecSize != 0)
4611 return op->emitOpError(
4612 "requires the bitwidth of the minor 1-D vector to be an integral "
4613 "multiple of the bitwidth of the minor 1-D vector of the source");
4615 unsigned sourceVecEltRank = vectorElementType.getRank();
4616 unsigned resultVecRank = vectorType.getRank();
4617 if (sourceVecEltRank > resultVecRank)
4618 return op->emitOpError(
4619 "requires source vector element and vector result ranks to match.");
4620 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4623 return op->emitOpError(
"requires a permutation_map with result dims of "
4624 "the same rank as the vector type");
4627 return op->emitOpError(
"does not support masks with vector element type");
4630 unsigned minorSize =
4631 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4632 unsigned resultVecSize =
4635 return op->emitOpError(
4636 "requires the bitwidth of the minor 1-D vector to be an integral "
4637 "multiple of the bitwidth of the source element type");
4641 return op->emitOpError(
"requires a permutation_map with result dims of "
4642 "the same rank as the vector type");
4646 return op->emitOpError(
"requires permutation_map without symbols");
4648 if (permutationMap.
getNumInputs() != shapedType.getRank())
4649 return op->emitOpError(
"requires a permutation_map with input dims of the "
4650 "same rank as the source type");
4652 if (maskType && maskType != inferredMaskType)
4653 return op->emitOpError(
"inferred mask type (")
4654 << inferredMaskType <<
") and mask operand type (" << maskType
4657 if (permutationMap.
getNumResults() !=
static_cast<int64_t
>(inBounds.size()))
4658 return op->emitOpError(
"expects the in_bounds attr of same rank "
4659 "as permutation_map results: ")
4661 <<
" vs inBounds of size: " << inBounds.size();
4668 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4669 if (op.getPermutationMap().isMinorIdentity())
4670 elidedAttrs.push_back(op.getPermutationMapAttrName());
4672 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
4673 elidedAttrs.push_back(op.getInBoundsAttrName());
4680 p <<
", " << getMask();
4689 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
4694 if (maskShape.empty())
4695 maskShape.push_back(1);
4717 if (hasMask.succeeded()) {
4724 if (types.size() != 2)
4725 return parser.
emitError(typesLoc,
"requires two types");
4727 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4728 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4729 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4730 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4732 return parser.
emitError(typesLoc,
"requires vector type");
4733 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.
name);
4737 if (shapedType.getRank() <
4740 "expected a custom permutation_map when "
4741 "rank(source) != rank(destination)");
4745 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4747 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.
name);
4749 if (!inBoundsAttr) {
4759 if (hasMask.succeeded()) {
4760 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4762 maskInfo.
location,
"does not support masks with vector element type");
4765 "expected the same rank for the vector and the "
4766 "results of the permutation map");
4774 result.
addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4776 {1, static_cast<int32_t>(indexInfo.size()), 1,
4777 static_cast<int32_t>(hasMask.succeeded())}));
4783 ShapedType shapedType = getShapedType();
4785 VectorType maskType = getMaskType();
4786 auto paddingType = getPadding().getType();
4787 auto permutationMap = getPermutationMap();
4788 VectorType inferredMaskType =
4791 auto sourceElementType = shapedType.getElementType();
4793 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
4794 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4797 shapedType, vectorType, maskType,
4798 inferredMaskType, permutationMap, getInBounds())))
4801 if (
auto sourceVectorElementType =
4802 llvm::dyn_cast<VectorType>(sourceElementType)) {
4805 if (sourceVectorElementType != paddingType)
4807 "requires source element type and padding type to match.");
4811 if (!VectorType::isValidElementType(paddingType))
4812 return emitOpError(
"requires valid padding vector elemental type");
4815 if (paddingType != sourceElementType)
4817 "requires formal padding and source of the same elemental type");
4821 [&](Twine t) {
return emitOpError(t); });
4828 Type TransferReadOp::getExpectedMaskType() {
4836 return cast<VectorType>(getVector().
getType());
4839 template <
typename TransferOp>
4840 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4843 if (op.getShapedType().isDynamicDim(indicesIdx))
4845 Value index = op.getIndices()[indicesIdx];
4847 if (!cstOp.has_value())
4850 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4851 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4853 return cstOp.value() + vectorSize <= sourceSize;
4856 template <
typename TransferOp>
4860 if (op.getTransferRank() == 0)
4865 newInBounds.reserve(op.getTransferRank());
4870 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
4872 if (op.isDimInBounds(i)) {
4873 newInBounds.push_back(
true);
4878 bool inBounds =
false;
4879 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
4882 dimExpr.getPosition());
4883 nonBcastDims.push_back(i);
4886 newInBounds.push_back(inBounds);
4894 bool allNonBcastDimsInBounds = llvm::all_of(
4895 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
4896 if (allNonBcastDimsInBounds) {
4899 newInBounds[idx] =
true;
4911 template <
typename TransferOp>
4913 auto mask = op.getMask();
4920 op.getMaskMutable().clear();
4934 static Value foldRAW(TransferReadOp readOp) {
4935 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4937 auto defWrite = readOp.getBase().
getDefiningOp<vector::TransferWriteOp>();
4940 return defWrite.getVector();
4942 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4943 cast<VectorTransferOpInterface>(readOp.getOperation())))
4945 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
4951 if (
Value vec = foldRAW(*
this))
4965 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4969 void TransferReadOp::getEffects(
4972 if (llvm::isa<MemRefType>(getShapedType()))
4978 if (hasPureTensorSemantics())
5006 struct TransferReadAfterWriteToBroadcast
5012 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5016 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5019 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5023 if (readOp.getTransferChunkAccessed() !=
5024 defWrite.getTransferChunkAccessed())
5043 if (readOp.getMask() || defWrite.getMask())
5046 if (readOp.getIndices() != defWrite.getIndices())
5049 Value vec = defWrite.getVector();
5069 broadcastShape[pos.value()] = destShape[pos.index()];
5070 broadcastScalableFlags[pos.value()] =
5071 readOp.getVectorType().getScalableDims()[pos.index()];
5074 broadcastShape, defWrite.getVectorType().getElementType(),
5075 broadcastScalableFlags);
5076 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
5087 results.
add<TransferReadAfterWriteToBroadcast>(context);
5097 AffineMapAttr permutationMapAttr,
5099 ArrayAttr inBoundsAttr) {
5100 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
5101 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5102 mask, inBoundsAttr);
5108 AffineMapAttr permutationMapAttr,
5109 ArrayAttr inBoundsAttr) {
5110 build(builder, result, vector, dest, indices, permutationMapAttr,
5111 Value(), inBoundsAttr);
5122 (inBounds && !inBounds.value().empty())
5125 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
5126 build(builder, result, vector, dest, indices, permutationMapAttr,
5127 Value(), inBoundsAttr);
5135 auto vectorType = llvm::cast<VectorType>(vector.
getType());
5137 llvm::cast<ShapedType>(dest.
getType()), vectorType);
5138 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5154 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
5159 if (types.size() != 2)
5160 return parser.
emitError(typesLoc,
"requires two types");
5162 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5164 return parser.
emitError(typesLoc,
"requires vector type");
5165 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5166 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5167 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5168 auto permMapAttrName =
5169 TransferWriteOp::getPermutationMapAttrName(result.
name);
5173 if (shapedType.getRank() <
5176 "expected a custom permutation_map when "
5177 "rank(source) != rank(destination)");
5181 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5183 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.
name);
5185 if (!inBoundsAttr) {
5194 if (hasMask.succeeded()) {
5195 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5197 maskInfo.
location,
"does not support masks with vector element type");
5200 "expected the same rank for the vector and the "
5201 "results of the permutation map");
5207 result.
addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5209 {1, 1, static_cast<int32_t>(indexInfo.size()),
5210 static_cast<int32_t>(hasMask.succeeded())}));
5211 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5218 p <<
", " << getMask();
5225 ShapedType shapedType = getShapedType();
5227 VectorType maskType = getMaskType();
5228 auto permutationMap = getPermutationMap();
5229 VectorType inferredMaskType =
5233 if (llvm::size(
getIndices()) != shapedType.getRank())
5234 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5238 if (hasBroadcastDim())
5239 return emitOpError(
"should not have broadcast dimensions");
5242 shapedType, vectorType, maskType,
5243 inferredMaskType, permutationMap, getInBounds())))
5247 [&](Twine t) {
return emitOpError(t); });
5256 Type TransferWriteOp::getExpectedMaskType() {
5263 Value TransferWriteOp::getVector() {
return getOperand(0); }
5265 return cast<VectorType>(getValueToStore().
getType());
5288 static LogicalResult foldReadInitWrite(TransferWriteOp write,
5292 if (write.getTransferRank() == 0)
5294 auto rankedTensorType =
5295 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5297 if (!rankedTensorType)
5300 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5304 if (read.getTransferRank() == 0)
5307 if (!read.getPermutationMap().isMinorIdentity() ||
5308 !write.getPermutationMap().isMinorIdentity())
5311 if (read.getTransferRank() != write.getTransferRank())
5314 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5317 if (read.getBase().getType() != rankedTensorType)
5320 if (read.getVectorType() != write.getVectorType())
5323 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5326 auto isNotConstantZero = [](
Value v) {
5328 return !cstOp.has_value() || cstOp.value() != 0;
5330 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5331 llvm::any_of(write.getIndices(), isNotConstantZero))
5334 results.push_back(read.getBase());
5338 static bool checkSameValueWAR(vector::TransferReadOp read,
5339 vector::TransferWriteOp write) {
5340 return read.getBase() == write.getBase() &&
5341 read.getIndices() == write.getIndices() &&
5342 read.getPermutationMap() == write.getPermutationMap() &&
5343 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5360 static LogicalResult foldWAR(TransferWriteOp write,
5362 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5364 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5368 if (!checkSameValueWAR(read, write))
5370 results.push_back(read.getBase());
5374 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5376 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
5378 if (succeeded(foldWAR(*
this, results)))
5390 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5394 void TransferWriteOp::getEffects(
5397 if (llvm::isa<MemRefType>(getShapedType()))
5403 if (hasPureTensorSemantics())
5438 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5440 vector::TransferWriteOp writeToModify = writeOp;
5442 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5446 writeToModify.getBaseMutable().assign(defWrite.getBase());
5451 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5452 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5456 if (!defWrite->hasOneUse())
5458 writeToModify = defWrite;
5459 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5488 struct SwapExtractSliceOfTransferWrite
5495 if (!insertOp.hasUnitStride())
5498 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5499 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5501 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5502 if (!transferOp || !transferOp->hasOneUse())
5507 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5509 "use-def chain is rank-reducing");
5513 if (!extractOp.hasZeroOffset()) {
5515 "ExtractSliceOp has non-zero offset");
5519 if (!llvm::all_of(transferOp.getIndices(), [](
Value value) {
5523 "TranferWriteOp has non-zero offset");
5527 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5529 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
5532 for (
auto [insertSize, extractSize] :
5533 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5536 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
5541 assert(transferOp.getVectorType().hasStaticShape() &&
5542 "expected vector to have a static shape");
5545 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5546 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
5548 insertOp,
"TransferWriteOp may not write the full tensor.");
5554 auto newExtractOp = tensor::ExtractSliceOp::create(
5555 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5556 insertOp.getDest(), insertOp.getMixedOffsets(),
5557 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5558 auto newTransferWriteOp = TransferWriteOp::create(
5559 rewriter, transferOp.getLoc(), transferOp.getVector(),
5560 newExtractOp.getResult(), transferOp.getIndices(),
5561 transferOp.getPermutationMapAttr(),
5564 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5574 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5581 static LogicalResult verifyLoadStoreMemRefLayout(
Operation *op,
5583 MemRefType memRefTy) {
5586 if (!vecTy.isScalable() &&
5587 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5590 if (!memRefTy.isLastDimUnitStride())
5591 return op->
emitOpError(
"most minor memref dim must have unit stride");
5599 if (
failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
5602 if (memRefTy.getRank() < resVecTy.getRank())
5604 "destination memref has lower rank than the result vector");
5607 Type memElemTy = memRefTy.getElementType();
5608 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5609 if (memVecTy != resVecTy)
5610 return emitOpError(
"base memref and result vector types should match");
5611 memElemTy = memVecTy.getElementType();
5614 if (resVecTy.getElementType() != memElemTy)
5615 return emitOpError(
"base and result element types should match");
5616 if (llvm::size(
getIndices()) != memRefTy.getRank())
5617 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5627 std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5639 if (
failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
5642 if (memRefTy.getRank() < valueVecTy.getRank())
5643 return emitOpError(
"source memref has lower rank than the vector to store");
5646 Type memElemTy = memRefTy.getElementType();
5647 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5648 if (memVecTy != valueVecTy)
5650 "base memref and valueToStore vector types should match");
5651 memElemTy = memVecTy.getElementType();
5654 if (valueVecTy.getElementType() != memElemTy)
5655 return emitOpError(
"base and valueToStore element type should match");
5656 if (llvm::size(
getIndices()) != memRefTy.getRank())
5657 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5661 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5666 std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5675 VectorType maskVType = getMaskVectorType();
5676 VectorType passVType = getPassThruVectorType();
5680 if (resVType.getElementType() != memType.getElementType())
5681 return emitOpError(
"base and result element type should match");
5682 if (llvm::size(
getIndices()) != memType.getRank())
5683 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5684 if (resVType.getShape() != maskVType.getShape())
5685 return emitOpError(
"expected result shape to match mask shape");
5686 if (resVType != passVType)
5687 return emitOpError(
"expected pass_thru of same type as result type");
5700 load, load.getType(), load.getBase(), load.getIndices());
5703 rewriter.
replaceOp(load, load.getPassThru());
5708 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
5715 results.
add<MaskedLoadFolder>(context);
5729 VectorType maskVType = getMaskVectorType();
5733 if (valueVType.getElementType() != memType.getElementType())
5734 return emitOpError(
"base and valueToStore element type should match");
5735 if (llvm::size(
getIndices()) != memType.getRank())
5736 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5737 if (valueVType.getShape() != maskVType.getShape())
5738 return emitOpError(
"expected valueToStore shape to match mask shape");
5751 store, store.getValueToStore(), store.getBase(), store.getIndices());
5759 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
5766 results.
add<MaskedStoreFolder>(context);
5769 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5779 VectorType indVType = getIndexVectorType();
5780 VectorType maskVType = getMaskVectorType();
5782 ShapedType baseType = getBaseType();
5784 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5785 return emitOpError(
"requires base to be a memref or ranked tensor type");
5787 if (resVType.getElementType() != baseType.getElementType())
5788 return emitOpError(
"base and result element type should match");
5789 if (llvm::size(getOffsets()) != baseType.getRank())
5790 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
5791 if (resVType.getShape() != indVType.getShape())
5792 return emitOpError(
"expected result dim to match indices dim");
5793 if (resVType.getShape() != maskVType.getShape())
5794 return emitOpError(
"expected result dim to match mask dim");
5795 if (resVType != getPassThruVectorType())
5796 return emitOpError(
"expected pass_thru of same type as result type");
5804 Type GatherOp::getExpectedMaskType() {
5805 auto vecType = this->getIndexVectorType();
5808 vecType.getScalableDims());
5811 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5816 static LogicalResult isZeroBasedContiguousSeq(
Value indexVec) {
5817 auto vecType = dyn_cast<VectorType>(indexVec.
getType());
5818 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5829 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5842 rewriter.
replaceOp(gather, gather.getPassThru());
5847 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
5858 if (!isa<MemRefType>(op.getBase().getType()))
5861 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
5865 op.getOffsets(), op.getMask(),
5874 results.
add<GatherFolder, FoldContiguousGather>(context);
5882 VectorType indVType = getIndexVectorType();
5883 VectorType maskVType = getMaskVectorType();
5887 if (valueVType.getElementType() != memType.getElementType())
5888 return emitOpError(
"base and valueToStore element type should match");
5889 if (llvm::size(getOffsets()) != memType.getRank())
5890 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5891 if (valueVType.getShape() != indVType.getShape())
5892 return emitOpError(
"expected valueToStore dim to match indices dim");
5893 if (valueVType.getShape() != maskVType.getShape())
5894 return emitOpError(
"expected valueToStore dim to match mask dim");
5913 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
5924 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
5928 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
5936 results.
add<ScatterFolder, FoldContiguousScatter>(context);
5944 VectorType maskVType = getMaskVectorType();
5945 VectorType passVType = getPassThruVectorType();
5949 if (resVType.getElementType() != memType.getElementType())
5950 return emitOpError(
"base and result element type should match");
5951 if (llvm::size(
getIndices()) != memType.getRank())
5952 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5953 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5954 return emitOpError(
"expected result dim to match mask dim");
5955 if (resVType != passVType)
5956 return emitOpError(
"expected pass_thru of same type as result type");
5969 expand, expand.getType(), expand.getBase(), expand.getIndices());
5972 rewriter.
replaceOp(expand, expand.getPassThru());
5977 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
5984 results.
add<ExpandLoadFolder>(context);
5992 VectorType maskVType = getMaskVectorType();
5996 if (valueVType.getElementType() != memType.getElementType())
5997 return emitOpError(
"base and valueToStore element type should match");
5998 if (llvm::size(
getIndices()) != memType.getRank())
5999 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6000 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6001 return emitOpError(
"expected valueToStore dim to match mask dim");
6006 class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
6014 compress, compress.getValueToStore(), compress.getBase(),
6015 compress.getIndices());
6023 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
6030 results.
add<CompressStoreFolder>(context);
6039 setResultRanges(getResult(), argRanges.front());
6044 VectorType sourceType = getSourceVectorType();
6045 VectorType resultType = getResultVectorType();
6048 if (sourceType.getElementType() != resultType.getElementType())
6049 return emitOpError(
"has different source and result element types");
6052 int64_t sourceNElms = sourceType.getNumElements();
6053 int64_t resultNElms = resultType.getNumElements();
6054 if (sourceNElms != resultNElms) {
6055 return emitOpError() <<
"has different number of elements at source ("
6056 << sourceNElms <<
") and result (" << resultNElms
6061 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6062 int64_t resultNScalableDims = resultType.getNumScalableDims();
6063 if (sourceNScalableDims != resultNScalableDims)
6064 return emitOpError() <<
"has different number of scalable dims at source ("
6065 << sourceNScalableDims <<
") and result ("
6066 << resultNScalableDims <<
")";
6075 static bool isOrderPreserving(TransposeOp transpose) {
6077 VectorType sourceType = transpose.getSourceVectorType();
6080 auto isNonScalableUnitDim = [&](int64_t dim) {
6081 return inShape[dim] == 1 && !inDimIsScalable[dim];
6083 int64_t current = 0;
6084 for (
auto p : permutation) {
6085 if (!isNonScalableUnitDim(p)) {
6097 VectorType resultType =
getType();
6100 if (getSource().
getType() == resultType)
6104 if (
auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6105 setOperand(precedingShapeCast.getSource());
6110 if (
auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6111 if (isOrderPreserving(transpose)) {
6112 setOperand(transpose.getVector());
6120 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6121 if (bcastOp.getSourceType() == resultType)
6122 return bcastOp.getSource();
6126 if (
auto denseAttr =
6127 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6128 return denseAttr.reshape(
getType());
6131 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
6144 static VectorType trimTrailingOneDims(VectorType oldType) {
6151 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6152 newShape = newShape.drop_back(1);
6153 newScalableDims = newScalableDims.drop_back(1);
6158 if (newShape.empty()) {
6159 newShape = oldShape.take_back();
6160 newScalableDims = oldScalableDims.take_back();
6163 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6178 class ShapeCastCreateMaskFolderTrailingOneDim final
6185 Value shapeOpSrc = shapeOp->getOperand(0);
6186 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
6187 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
6188 if (!createMaskOp && !constantMaskOp)
6191 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6192 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6194 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6195 if (newVecType != shapeOpResTy)
6198 auto numDimsToDrop =
6199 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6206 auto maskOperands = createMaskOp.getOperands();
6207 auto numMaskOperands = maskOperands.size();
6210 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6212 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6213 if (!constant || (constant.value() != 1))
6217 maskOperands.drop_back(numDimsToDrop);
6224 if (constantMaskOp) {
6225 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6226 auto numMaskOperands = maskDimSizes.size();
6229 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6231 if (maskDimSizes[i] != 1)
6235 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6249 class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
6256 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6260 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6261 bool srcIsScalar = !srcVectorType;
6269 if (srcVectorType) {
6270 if (srcVectorType.getNumElements() ==
6271 shapeCastOp.getResultVectorType().getNumElements()) {
6273 shapeCastOp, shapeCastOp.getResultVectorType(),
6274 broadcastOp.getSource());
6285 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6287 BroadcastableToResult::Success) {
6289 shapeCastOp, dstVectorType, broadcastOp.getSource());
6301 .
add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6310 auto sourceVectorType = getSourceVectorType();
6311 auto resultVectorType = getResultVectorType();
6313 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6314 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6315 return emitOpError(
"dimension size mismatch at: ") << i;
6318 DataLayout dataLayout = DataLayout::closest(*
this);
6319 auto sourceElementBits =
6321 auto resultElementBits =
6324 if (sourceVectorType.getRank() == 0) {
6325 if (sourceElementBits != resultElementBits)
6326 return emitOpError(
"source/result bitwidth of the 0-D vector element "
6327 "types must be equal");
6328 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
6329 resultElementBits * resultVectorType.getShape().back()) {
6331 "source/result bitwidth of the minor 1-D vectors must be equal");
6343 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6344 if (getResult().
getType() == otherOp.getSource().getType())
6345 return otherOp.getSource();
6347 setOperand(otherOp.getSource());
6351 Attribute sourceConstant = adaptor.getSource();
6352 if (!sourceConstant)
6355 Type srcElemType = getSourceVectorType().getElementType();
6356 Type dstElemType = getResultVectorType().getElementType();
6358 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6359 if (floatPack.isSplat()) {
6360 auto splat = floatPack.getSplatValue<FloatAttr>();
6363 if (srcElemType.
isF16() && dstElemType.
isF32()) {
6364 uint32_t bits =
static_cast<uint32_t
>(
6365 splat.getValue().bitcastToAPInt().getZExtValue());
6367 bits = (bits << 16) | (bits & 0xffff);
6368 APInt intBits(32, bits);
6369 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6375 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6376 if (intPack.isSplat()) {
6377 auto splat = intPack.getSplatValue<IntegerAttr>();
6379 if (llvm::isa<IntegerType>(dstElemType)) {
6384 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6385 APInt intBits = splat.getValue().zext(dstBitWidth);
6388 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6389 intBits = (intBits << srcBitWidth) | intBits;
6404 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6407 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6416 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
6417 VectorType vectorType =
6421 memRefType.getMemorySpace()));
6425 MemRefType canonicalType =
getMemRefType().canonicalizeStridedLayout();
6426 if (!canonicalType.getLayout().isIdentity())
6427 return emitOpError(
"expects operand to be a memref with identity layout");
6428 if (!getResultMemRefType().getLayout().isIdentity())
6429 return emitOpError(
"expects result to be a memref with identity layout");
6430 if (getResultMemRefType().getMemorySpace() !=
6432 return emitOpError(
"expects result in same memory space");
6435 auto resultType = getResultMemRefType();
6439 "expects result and operand with same underlying scalar type: ")
6441 if (extractShape(sourceType) != extractShape(resultType))
6443 "expects concatenated result and operand shapes to be equal: ")
6454 VectorType vt = llvm::cast<VectorType>(vector.
getType());
6457 for (
unsigned i = 0; i < permutation.size(); ++i) {
6458 transposedShape[i] = vt.getShape()[permutation[i]];
6459 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6464 transposedScalableDims));
6469 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6472 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6473 return splat.reshape(getResultVectorType());
6476 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6490 if (getSourceVectorType() == getResultVectorType() &&
6491 isOrderPreserving(*
this))
6498 VectorType vectorType = getSourceVectorType();
6499 VectorType resultType = getResultVectorType();
6500 int64_t rank = resultType.getRank();
6501 if (vectorType.getRank() != rank)
6502 return emitOpError(
"vector result rank mismatch: ") << rank;
6505 int64_t size = perm.size();
6507 return emitOpError(
"transposition length mismatch: ") << size;
6510 if (ta.value() < 0 || ta.value() >= rank)
6511 return emitOpError(
"transposition index out of range: ") << ta.value();
6512 if (seen[ta.value()])
6513 return emitOpError(
"duplicate position index: ") << ta.value();
6514 seen[ta.value()] =
true;
6515 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6516 return emitOpError(
"dimension size mismatch at: ") << ta.value();
6521 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6522 return llvm::to_vector<4>(getResultVectorType().
getShape());
6527 setResultRanges(getResult(), argRanges.front());
6533 class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
6543 for (
auto index : permutation2)
6544 result.push_back(permutation1[index]);
6549 vector::TransposeOp parentTransposeOp =
6550 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6551 if (!parentTransposeOp)
6555 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6558 transposeOp, transposeOp.getResult().getType(),
6559 parentTransposeOp.getVector(), permutation);
6571 Value splat = getScalarSplatSource(transposeOp.getVector());
6576 transposeOp, transposeOp.getResultVectorType(), splat);
6582 class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
6588 Value transposeSrc = transpOp.getVector();
6589 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
6590 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
6591 if (!createMaskOp && !constantMaskOp)
6599 auto maskOperands = createMaskOp.getOperands();
6604 transpOp, transpOp.getResultVectorType(), newOperands);
6609 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6613 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6619 class FoldTransposeShapeCast final :
public OpRewritePattern<TransposeOp> {
6626 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6629 if (!isOrderPreserving(transposeOp))
6632 VectorType resultType = transposeOp.getType();
6639 shapeCastOp.getSource());
6669 class FoldTransposeBroadcast :
public OpRewritePattern<vector::TransposeOp> {
6682 "not preceded by a broadcast");
6685 auto inputType = dyn_cast<VectorType>(
broadcast.getSourceType());
6686 VectorType outputType = transpose.getResultVectorType();
6689 bool inputIsScalar = !inputType;
6690 if (inputIsScalar) {
6698 int64_t inputRank = inputType.getRank();
6699 int64_t outputRank = transpose.getType().getRank();
6700 int64_t deltaRank = outputRank - inputRank;
6703 for (
int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6704 bool notOne = inputShape[inputIndex] != 1;
6705 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6706 bool groupEndFound = notOne || prevNotOne;
6707 if (groupEndFound) {
6708 int high = inputIndex + deltaRank;
6712 for (
int i = low; i < high; ++i) {
6713 if (permutation[i] < low || permutation[i] >= high) {
6715 transpose,
"permutation not local to group");
6729 vector::BroadcastableToResult::Success &&
6730 "not broadcastable directly to transpose output");
6741 void vector::TransposeOp::getCanonicalizationPatterns(
6743 results.
add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6744 FoldTransposeSplat, FoldTransposeBroadcast>(context);
6753 assert(
kind == ConstantMaskKind::AllTrue ||
6754 kind == ConstantMaskKind::AllFalse);
6755 build(builder, result, type,
6756 kind == ConstantMaskKind::AllTrue
6762 auto resultType = llvm::cast<VectorType>(getResult().
getType());
6764 if (resultType.getRank() == 0) {
6765 if (getMaskDimSizes().size() != 1)
6766 return emitError(
"array attr must have length 1 for 0-D vectors");
6767 auto dim = getMaskDimSizes()[0];
6768 if (dim != 0 && dim != 1)
6769 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
6774 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
6776 "must specify array attr of size equal vector result rank");
6779 auto resultShape = resultType.getShape();
6780 auto resultScalableDims = resultType.getScalableDims();
6782 for (
const auto [index, maskDimSize] :
llvm::enumerate(maskDimSizes)) {
6783 if (maskDimSize < 0 || maskDimSize > resultShape[index])
6785 "array attr of size out of bounds of vector result dimension size");
6786 if (resultScalableDims[index] && maskDimSize != 0 &&
6787 maskDimSize != resultShape[index])
6789 "only supports 'none set' or 'all set' scalable dimensions");
6793 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6794 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
6795 if (anyZeros && !allZeros)
6796 return emitOpError(
"expected all mask dim sizes to be zeros, "
6797 "as a result of conjunction with zero mask dim");
6801 bool ConstantMaskOp::isAllOnesMask() {
6804 if (resultType.getRank() == 0) {
6805 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
6806 return getMaskDimSizes()[0] == 1;
6808 for (
const auto [resultSize, maskDimSize] :
6809 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6810 if (maskDimSize < resultSize)
6816 OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
6820 auto createBoolSplat = [&](
bool x) {
6826 if (vectorSizes.empty()) {
6827 assert(bounds.size() == 1 &&
"invalid sizes for zero rank mask");
6828 return createBoolSplat(bounds[0] == 1);
6831 if (bounds == vectorSizes)
6832 return createBoolSplat(
true);
6833 if (llvm::all_of(bounds, [](int64_t x) {
return x == 0; }))
6834 return createBoolSplat(
false);
6847 build(builder, result, type, operands);
6851 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
6853 if (vectorType.getRank() == 0) {
6854 if (getNumOperands() != 1)
6856 "must specify exactly one operand for 0-D create_mask");
6857 }
else if (getNumOperands() !=
6858 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
6860 "must specify an operand for each result vector dimension");
// CreateMaskFolder: folds a vector.create_mask with constant operands into a
// vector.constant_mask.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank-zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Collect the constant mask dimension sizes; bail out if any operand
    // cannot be folded.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // A constant size for a scalable dim only folds if it is "none set".
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // A vscale multiple must be "all set" (>= the static dim size).
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dim sizes is zero, set all of them to zero.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
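// Illustrative sketch, not part of the file: with all-constant operands,
// CreateMaskFolder turns vector.create_mask into vector.constant_mask:
//
//   %c2 = arith.constant 2 : index
//   %m  = vector.create_mask %c2, %c2 : vector<4x3xi1>
//
// folds to:
//
//   %m = vector.constant_mask [2, 2] : vector<4x3xi1>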
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");
  // ... (add the mask operand and create the region block)
  maskRegionBuilder(builder, maskableOp);
}

// The result-typed overloads delegate to the builder above:
//   build(builder, result, resultTypes, mask, Value(), maskableOp,
//         maskRegionBuilder);                          // no passthru
//   build(builder, result, mask, maskableOp, maskRegionBuilder);
//   followed by adding the passthru operand and result types.
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // ... (parse the mask operand)
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();
  // ... (parse the mask region)
  MaskOp::ensureTerminator(maskRegion, builder, result.location);
  // ... (parse attributes and types)
  result.types.append(resultTypes);
  // ... (resolve the mask operand)
  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");
    // ... (resolve the passthru operand against the result type)
  }
  return success();
}

void MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " } ";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // For an empty region, install the default vector.yield terminator.
  if (region.empty() || region.front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }
  // Do nothing if an explicit terminator is already present.
  Block &block = region.front();
  if (isa<vector::YieldOp>(block.back()))
    return;
  // ... (fall back to the default terminator when the op count is unexpected)
  // Otherwise, yield the results of the single masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask: nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Elides empty vector.mask operations with or without return values.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();

  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  if (terminator.getNumOperands() == 0) {
    // The mask yields no values: there is nothing to propagate.
    return success();
  }

  llvm::append_range(results, terminator.getOperands());
  return success();
}

LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();
  // ... (only an all-true mask can be folded away; move the masked op out of
  // the region and forward its results)
  Operation *maskableOp = getMaskableOp();
  // ...
  llvm::append_range(results, maskableOp->getResults());
  return success();
}
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    if (!maskOp.isEmpty())
      return failure();
    if (!maskOp.hasPassthru())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    // Replace the vector.mask with a select between the yielded value and the
    // passthru.
    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
        terminator.getOperand(0), maskOp.getPassthru());
    return success();
  }
};

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;
  return &block->front();
}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check that the reduction dimension is < rank(source).
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(source) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check the shapes of the initial value and the source.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify that the reduction kind is supported for the element type.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};
  // ... (return a splat attribute built from the constant operand)
}

// ... (SplatToBroadcastPattern: canonicalizes vector.splat into a
// vector.broadcast of splatOp.getOperand())

void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SplatToBroadcastPattern>(context);
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRanges) {
  auto resultType = cast<VectorType>(getType());
  if (resultType.isScalable()) {
    return;
  }
  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
  APInt zero(bitwidth, 0);
  APInt high(bitwidth, resultType.getDimSize(0) - 1);
  ConstantIntRanges result = {zero, high, zero, high};
  setResultRanges(getResult(), result);
}
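// Worked example of the inference above: for vector.step : vector<8xindex>
// the elements are 0..7, so the unsigned and signed ranges are both [0, 7].
// Scalable vectors return early because the last element is not known
// statically.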
/// Create the vector.yield-ended region of a vector.mask op with maskableOp
/// as the masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  // ... (move maskableOp into the region block and create the yield)
}

Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return MaskOp::create(builder, maskableOp->getLoc(),
                          maskableOp->getResultTypes(), mask, passthru,
                          maskableOp, createMaskOpRegion);
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, maskableOp,
                        createMaskOpRegion);
}

Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;
  return arith::SelectOp::create(builder, newValue.getLoc(),
                                 newValue.getType(), mask, newValue, passthru);
}
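// Usage sketch with hypothetical values: wrap a maskable op (e.g. a
// vector.transfer_read) in a vector.mask, or leave it unchanged when no mask
// is provided; selectPassthru is the value-level analogue that emits an
// arith.select between the new value and the passthru lanes.
//
//   Operation *masked =
//       vector::maskOperation(builder, xferOp.getOperation(), mask, passthru);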
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"