41 #include "llvm/ADT/ArrayRef.h"
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/SmallVector.h"
44 #include "llvm/ADT/StringSet.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Casting.h"
51 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
53 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
74 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
76 for (
bool b : denseElts.getValues<
bool>())
79 else if (!b && val <= 0)
93 auto shape = m.getType().getShape();
96 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
97 if (maskIdx < dimSize)
110 auto maskOperands = m.getOperands();
111 for (
Value operand : maskOperands) {
112 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
114 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
127 vector::YieldOp::create(builder, loc);
133 switch (combiningKind) {
134 case CombiningKind::ADD:
135 case CombiningKind::MUL:
138 case CombiningKind::MINSI:
139 case CombiningKind::MAXUI:
140 case CombiningKind::MAXSI:
141 case CombiningKind::AND:
142 case CombiningKind::OR:
143 case CombiningKind::XOR:
145 case CombiningKind::MINNUMF:
146 case CombiningKind::MAXNUMF:
147 case CombiningKind::MINIMUMF:
148 case CombiningKind::MAXIMUMF:
149 return llvm::isa<FloatType>(elementType);
179 VectorType vectorType) {
180 unsigned elementVectorRank = 0;
181 VectorType elementVectorType =
182 llvm::dyn_cast<VectorType>(shapedType.getElementType());
183 if (elementVectorType)
184 elementVectorRank += elementVectorType.getRank();
185 return vectorType.getRank() - elementVectorRank;
189 VectorType vectorType) {
192 if (shapedType.getRank() == 0 &&
198 shapedType.getRank(),
200 shapedType.getContext());
207 vector::TransferReadOp read) {
208 auto readMask = read.getMask();
209 auto writeMask = write.getMask();
215 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
216 if (!couldBeSameSplat)
221 m_Constant<DenseElementsAttr>(&splatAttr)) ||
233 vector::TransferReadOp read) {
234 return !defWrite.hasOutOfBoundsDim() &&
235 defWrite.getIndices() == read.getIndices() &&
236 defWrite.getVectorType() == read.getVectorType() &&
237 defWrite.getPermutationMap() == read.getPermutationMap() &&
238 ((!defWrite.getMask() && !read.getMask()) ||
243 vector::TransferWriteOp priorWrite) {
244 return priorWrite.getIndices() == write.getIndices() &&
245 priorWrite.getMask() == write.getMask() &&
246 priorWrite.getVectorType() == write.getVectorType() &&
247 priorWrite.getPermutationMap() == write.getPermutationMap();
251 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
252 bool testDynamicValueUsingBounds) {
254 if (transferA.getVectorType() != transferB.getVectorType())
256 unsigned rankOffset = transferA.getLeadingShapedRank();
257 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
258 Value indexA = transferA.getIndices()[i];
259 Value indexB = transferB.getIndices()[i];
263 if (i < rankOffset) {
266 if (cstIndexA.has_value() && cstIndexB.has_value()) {
267 if (*cstIndexA != *cstIndexB)
271 if (testDynamicValueUsingBounds) {
274 FailureOr<uint64_t> delta =
276 if (succeeded(delta) && *delta != 0)
279 FailureOr<bool> testEqual =
281 if (succeeded(testEqual) && !testEqual.value())
287 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
288 if (cstIndexA.has_value() && cstIndexB.has_value()) {
289 int64_t distance =
std::abs(*cstIndexA - *cstIndexB);
290 if (distance >= vectorDim)
294 if (testDynamicValueUsingBounds) {
297 FailureOr<int64_t> delta =
299 if (succeeded(delta) &&
std::abs(*delta) >= vectorDim)
302 FailureOr<int64_t> computeDelta =
304 if (succeeded(computeDelta)) {
305 if (
std::abs(computeDelta.value()) >= vectorDim)
315 VectorTransferOpInterface transferB,
316 bool testDynamicValueUsingBounds) {
317 if (transferA.getBase() != transferB.getBase())
320 testDynamicValueUsingBounds);
330 for (
auto [posInDim, dimSize, offsetInDim] :
331 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
333 if (posInDim < dimSize + offsetInDim)
337 posInDim = offsetInDim;
347 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
349 assert(constOp &&
"Unexpected non-constant index");
350 return constOp.value();
360 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
361 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
362 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
372 llvm::transform(foldResults, std::back_inserter(values),
374 if (
auto attr = dyn_cast<Attribute>(foldResult))
376 builder, loc, cast<IntegerAttr>(attr).getInt())
379 return cast<Value>(foldResult);
390 auto lhs = mul.getLhs();
391 auto rhs = mul.getRhs();
392 if (lhs.getDefiningOp<vector::VectorScaleOp>())
394 if (rhs.getDefiningOp<vector::VectorScaleOp>())
403 if (intAttr.getType() == expectedType)
452 void VectorDialect::initialize() {
454 #define GET_ATTRDEF_LIST
455 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
460 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
463 addInterfaces<VectorInlinerInterface>();
465 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
466 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
468 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
470 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
471 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
472 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
480 if (isa<ub::PoisonAttrInterface>(value))
483 return arith::ConstantOp::materialize(builder, value, type, loc);
499 void vector::MultiDimReductionOp::build(
OpBuilder &builder,
502 CombiningKind
kind) {
506 reductionDims.push_back(en.index());
507 build(builder, result,
kind, source, acc, reductionDims);
510 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
512 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
517 std::optional<SmallVector<int64_t, 4>>
518 MultiDimReductionOp::getShapeForUnroll() {
519 return llvm::to_vector<4>(getSourceVectorType().
getShape());
525 Type inferredReturnType;
526 auto sourceScalableDims = getSourceVectorType().getScalableDims();
527 for (
auto [dimIdx, dimSize] :
529 if (!llvm::any_of(getReductionDims(),
530 [dimIdx = dimIdx](int64_t reductionDimIdx) {
531 return reductionDimIdx ==
static_cast<int64_t
>(dimIdx);
533 targetShape.push_back(dimSize);
534 scalableDims.push_back(sourceScalableDims[dimIdx]);
537 if (targetShape.empty())
538 inferredReturnType = getSourceVectorType().getElementType();
541 targetShape, getSourceVectorType().
getElementType(), scalableDims);
542 if (
getType() != inferredReturnType)
543 return emitOpError() <<
"destination type " <<
getType()
544 <<
" is incompatible with source type "
545 << getSourceVectorType();
551 Type MultiDimReductionOp::getExpectedMaskType() {
552 auto vecType = getSourceVectorType();
555 vecType.getScalableDims());
564 struct ElideUnitDimsInMultiDimReduction
568 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
571 for (
const auto &dim :
enumerate(shape)) {
572 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
580 if (reductionOp.isMasked()) {
582 rootOp = reductionOp.getMaskingOp();
583 mask = reductionOp.getMaskingOp().getMask();
585 rootOp = reductionOp;
588 Location loc = reductionOp.getLoc();
589 Value acc = reductionOp.getAcc();
591 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
593 VectorType newMaskType =
595 dstVecType.getScalableDims());
596 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
598 cast = vector::ShapeCastOp::create(
599 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
604 mask = vector::ExtractOp::create(rewriter, loc, mask);
605 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
610 cast,
nullptr, mask);
617 void MultiDimReductionOp::getCanonicalizationPatterns(
619 results.
add<ElideUnitDimsInMultiDimReduction>(context);
628 arith::FastMathFlags fastMathFlags) {
629 build(builder, result,
kind, vector,
Value(), fastMathFlags);
634 arith::FastMathFlags fastMathFlags) {
635 build(builder, result,
636 llvm::cast<VectorType>(vector.
getType()).getElementType(),
kind, vector,
642 int64_t rank = getSourceVectorType().getRank();
644 return emitOpError(
"unsupported reduction rank: ") << rank;
647 Type eltType = getDest().getType();
649 return emitOpError(
"unsupported reduction type '")
650 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
659 Type ReductionOp::getExpectedMaskType() {
660 auto vecType = getSourceVectorType();
663 vecType.getScalableDims());
670 case arith::AtomicRMWKind::addf:
671 case arith::AtomicRMWKind::addi:
672 return vector::ReductionOp::create(builder, vector.
getLoc(),
673 CombiningKind::ADD, vector);
674 case arith::AtomicRMWKind::mulf:
675 case arith::AtomicRMWKind::muli:
676 return vector::ReductionOp::create(builder, vector.
getLoc(),
677 CombiningKind::MUL, vector);
678 case arith::AtomicRMWKind::minimumf:
679 return vector::ReductionOp::create(builder, vector.
getLoc(),
680 CombiningKind::MINIMUMF, vector);
681 case arith::AtomicRMWKind::mins:
682 return vector::ReductionOp::create(builder, vector.
getLoc(),
683 CombiningKind::MINSI, vector);
684 case arith::AtomicRMWKind::minu:
685 return vector::ReductionOp::create(builder, vector.
getLoc(),
687 case arith::AtomicRMWKind::maximumf:
688 return vector::ReductionOp::create(builder, vector.
getLoc(),
689 CombiningKind::MAXIMUMF, vector);
690 case arith::AtomicRMWKind::maxs:
691 return vector::ReductionOp::create(builder, vector.
getLoc(),
692 CombiningKind::MAXSI, vector);
693 case arith::AtomicRMWKind::maxu:
694 return vector::ReductionOp::create(builder, vector.
getLoc(),
695 CombiningKind::MAXUI, vector);
696 case arith::AtomicRMWKind::andi:
697 return vector::ReductionOp::create(builder, vector.
getLoc(),
698 CombiningKind::AND, vector);
699 case arith::AtomicRMWKind::ori:
700 return vector::ReductionOp::create(builder, vector.
getLoc(),
701 CombiningKind::OR, vector);
710 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
711 return llvm::to_vector<4>(getSourceVectorType().
getShape());
718 LogicalResult matchAndRewrite(ReductionOp reductionOp,
723 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
726 if (maskableOp.isMasked()) {
728 rootOp = maskableOp.getMaskingOp();
729 mask = maskableOp.getMaskingOp().getMask();
731 rootOp = reductionOp;
734 auto vectorType = reductionOp.getSourceVectorType();
735 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
738 Location loc = reductionOp.getLoc();
740 mask = ExtractOp::create(rewriter, loc, mask);
741 Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
743 if (
Value acc = reductionOp.getAcc())
746 reductionOp.getFastmathAttr(), mask);
756 results.
add<ElideSingleElementReduction>(context);
770 getIndexingMapsAttrName(result.
name),
774 getIteratorTypesAttrName(result.
name),
777 return IteratorTypeAttr::get(builder.getContext(), t);
783 ArrayAttr indexingMaps,
784 ArrayAttr iteratorTypes) {
785 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
786 ContractionOp::getDefaultKind());
791 ArrayAttr indexingMaps,
792 ArrayAttr iteratorTypes, CombiningKind
kind) {
809 DictionaryAttr dictAttr;
824 dictAttr.getValue().end());
830 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
832 if (!iteratorTypes) {
834 <<
"expected " << getIteratorTypesAttrName(result.
name)
835 <<
" array attribute";
840 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
841 auto maybeIteratorType = symbolizeIteratorType(s);
842 if (!maybeIteratorType.has_value())
843 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
845 iteratorTypeAttrs.push_back(
853 getKindAttrName(result.
name),
855 ContractionOp::getDefaultKind()));
857 if (masksInfo.empty())
859 if (masksInfo.size() != 2)
861 "expected zero or exactly 2 vector mask operands");
862 auto lhsType = llvm::cast<VectorType>(types[0]);
863 auto rhsType = llvm::cast<VectorType>(types[1]);
865 std::array<VectorType, 2> maskTypes = {
875 auto attrNames = getTraitAttrNames();
877 traitAttrsSet.insert_range(attrNames);
879 for (
auto attr : (*this)->getAttrs()) {
880 if (attr.getName() == getIteratorTypesAttrName()) {
882 llvm::cast<ArrayAttr>(attr.getValue())
883 .getAsValueRange<IteratorTypeAttr, IteratorType>();
889 llvm::map_range(iteratorTypes, [&](IteratorType t) ->
Attribute {
893 attrs.emplace_back(getIteratorTypesAttrName(),
895 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
896 attrs.push_back(attr);
900 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
901 p << getRhs() <<
", " << getAcc();
904 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
909 const std::vector<std::pair<int64_t, int64_t>> &map) {
910 for (
auto &dimPair : map) {
911 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
912 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
913 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
920 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
922 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
923 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
926 for (
auto &dimPair : contractingDimMap) {
927 lhsContractingDimSet.insert(dimPair.first);
928 rhsContractingDimSet.insert(dimPair.second);
931 llvm::make_second_range(batchDimMap));
935 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
936 if (lhsContractingDimSet.count(i) > 0)
938 expectedResultDims.push_back(lhsType.getDimSize(i));
942 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
943 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
945 expectedResultDims.push_back(rhsType.getDimSize(i));
949 if (expectedResultDims.empty()) {
951 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
952 return op.emitOpError(
"invalid accumulator/result vector shape");
955 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
956 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
957 if (!resVectorType || !accVectorType)
958 return op.emitOpError(
"invalid accumulator/result vector shape");
964 AffineMap lhsMap = op.getIndexingMapsArray()[0];
965 AffineMap rhsMap = op.getIndexingMapsArray()[1];
967 return op.emitOpError(
968 "expected all dimensions to be either a LHS or a RHS dimension");
971 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
972 VectorType v = pair.first;
973 auto map = pair.second;
974 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
975 unsigned pos = map.getDimPosition(idx);
980 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
981 return op.emitOpError(
"expected all dimensions to get an extent as "
982 "either a LHS or a RHS dimension");
984 AffineMap resMap = op.getIndexingMapsArray()[2];
990 llvm::IsaPred<AffineConstantExpr>) &&
991 "expected constant extent along all dimensions.");
993 auto expectedShape = llvm::to_vector<4>(
995 return cast<AffineConstantExpr>(e).getValue();
999 resVectorType.getScalableDims());
1000 if (resVectorType != expected || accVectorType != expected)
1001 return op.emitOpError(
1002 "invalid accumulator/result vector shape, expected: ")
1009 VectorType lhsType = getLhsType();
1010 VectorType rhsType = getRhsType();
1011 Type accType = getAccType();
1012 Type resType = getResultType();
1014 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1015 if (!lhsType.getElementType().isSignlessInteger())
1016 return emitOpError(
"only supports signless integer types");
1020 if (getIndexingMapsArray().size() != 3)
1021 return emitOpError(
"expected an indexing map for each vector operand");
1026 unsigned numIterators = getIteratorTypes().getValue().size();
1028 auto index = it.index();
1029 auto map = it.value();
1030 if (map.getNumSymbols() != 0)
1031 return emitOpError(
"expected indexing map ")
1032 << index <<
" to have no symbols";
1033 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).
getType());
1034 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1037 if (map.getNumDims() != numIterators)
1038 return emitOpError(
"expected indexing map ")
1039 << index <<
" to have " << numIterators <<
" number of inputs";
1040 if (map.getNumResults() != rank)
1041 return emitOpError(
"expected indexing map ")
1042 << index <<
" to have " << rank <<
" number of outputs";
1043 if (!map.isProjectedPermutation())
1044 return emitOpError(
"expected indexing map ")
1045 << index <<
" to be a projected permutation of its inputs";
1048 auto contractingDimMap = getContractingDimMap();
1049 auto batchDimMap = getBatchDimMap();
1052 if (contractingDimMap.empty())
1053 return emitOpError(
"expected at least one contracting dimension pair");
1056 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1057 return emitOpError(
"invalid contracting dimension map");
1061 return emitOpError(
"invalid batch dimension map");
1065 contractingDimMap, batchDimMap)))
1069 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1070 auto elementType = vectorType ? vectorType.getElementType() : resType;
1072 return emitOpError(
"unsupported contraction type");
1075 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1082 Type ContractionOp::getExpectedMaskType() {
1083 auto indexingMaps = this->getIndexingMapsArray();
1086 VectorType lhsType = this->getLhsType();
1087 VectorType rhsType = this->getRhsType();
1089 unsigned numVecDims = lhsIdxMap.
getNumDims();
1098 lhsType.getScalableDims()[dimIdx];
1103 rhsType.getScalableDims()[dimIdx];
1106 assert(ShapedType::isStaticShape(maskShape) &&
1107 "Mask shape couldn't be computed");
1111 maskShapeScalableDims);
1116 getIteratorTypesAttrName(), getKindAttrName()};
1126 static std::vector<std::pair<int64_t, int64_t>>
1128 IteratorType targetIteratorType,
MLIRContext *context) {
1129 std::vector<std::pair<int64_t, int64_t>> dimMap;
1131 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1132 if (iteratorType != targetIteratorType)
1138 if (lhsDim >= 0 && rhsDim >= 0)
1139 dimMap.emplace_back(lhsDim, rhsDim);
1144 void ContractionOp::getIterationBounds(
1146 auto lhsShape = getLhsType().getShape();
1147 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1152 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1153 if (iteratorType == IteratorType::reduction) {
1155 int64_t lhsDimIndex =
getResultIndex(indexingMaps[0], targetExpr);
1156 assert(lhsDimIndex >= 0);
1157 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1161 int64_t resDimIndex =
getResultIndex(indexingMaps[2], targetExpr);
1162 assert(resDimIndex >= 0);
1163 assert(resVectorType !=
nullptr);
1164 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1168 void ContractionOp::getIterationIndexMap(
1170 unsigned numMaps = getIndexingMapsArray().size();
1171 iterationIndexMap.resize(numMaps);
1173 auto index = it.index();
1174 auto map = it.value();
1175 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1176 auto dim = cast<AffineDimExpr>(map.getResult(i));
1177 iterationIndexMap[index][dim.getPosition()] = i;
1182 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1184 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1188 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1190 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1194 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1196 getIterationBounds(shape);
1218 template <
typename AddOpType>
1224 auto canonicalize = [&](
Value maybeContraction,
1225 Value otherOperand) -> vector::ContractionOp {
1226 vector::ContractionOp contractionOp =
1227 dyn_cast_or_null<vector::ContractionOp>(
1230 return vector::ContractionOp();
1231 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1232 contractionOp.getAcc().getDefiningOp())) {
1233 if (maybeZero.getValue() ==
1234 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1236 bvm.
map(contractionOp.getAcc(), otherOperand);
1237 auto newContraction =
1238 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1239 rewriter.
replaceOp(addOp, newContraction.getResult());
1240 return newContraction;
1243 return vector::ContractionOp();
1246 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1247 vector::ContractionOp
contract = canonicalize(a, b);
1249 return contract ? success() : failure();
1263 return index == poisonValue || (index >= 0 && index < maxIndex);
1272 setResultRanges(getResult(), argRanges.front());
1277 auto vectorTy = cast<VectorType>(source.
getType());
1282 Value source, int64_t position) {
1302 build(builder, result, source, dynamicPos,
1307 ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1308 ExtractOp::Adaptor adaptor,
1310 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1311 if (
static_cast<int64_t
>(adaptor.getStaticPosition().size()) ==
1312 vectorType.getRank()) {
1313 inferredReturnTypes.push_back(vectorType.getElementType());
1315 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1316 vectorType.getRank());
1318 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1319 vectorType.getScalableDims().drop_front(n)));
1327 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1328 return vectorType && vectorType.getShape().equals({1}) &&
1329 vectorType.getElementType() == r.front();
1331 if (l.size() == 1 && r.size() == 1 &&
1332 (isCompatible(l, r) || isCompatible(r, l)))
1338 if (
auto resTy = dyn_cast<VectorType>(getResult().
getType()))
1339 if (resTy.getRank() == 0)
1341 "expected a scalar instead of a 0-d vector as the result type");
1344 auto dynamicMarkersCount =
1345 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1346 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1348 "mismatch between dynamic and static positions (kDynamic marker but no "
1349 "corresponding dynamic position) -- this can only happen due to an "
1350 "incorrect fold/rewrite");
1351 auto position = getMixedPosition();
1352 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1354 "expected position attribute of rank no greater than vector rank");
1356 if (
auto attr = dyn_cast<Attribute>(pos)) {
1357 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1359 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1360 return emitOpError(
"expected position attribute #")
1362 <<
" to be a non-negative integer smaller than the "
1363 "corresponding vector dimension or poison (-1)";
1370 template <
typename IntType>
1372 return llvm::to_vector<4>(llvm::map_range(
1373 arrayAttr.getAsRange<IntegerAttr>(),
1374 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1380 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1384 if (extractOp.hasDynamicPosition())
1388 ExtractOp currentOp = extractOp;
1390 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1391 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1394 if (currentOp.hasDynamicPosition())
1397 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1399 extractOp.setOperand(0, currentOp.getSource());
1402 std::reverse(globalPosition.begin(), globalPosition.end());
1403 extractOp.setStaticPosition(globalPosition);
1415 class ExtractFromInsertTransposeChainState {
1417 ExtractFromInsertTransposeChainState(ExtractOp e);
1426 template <
typename ContainerA,
typename ContainerB>
1427 bool isContainedWithin(
const ContainerA &a,
const ContainerB &b) {
1428 return a.size() <= b.size() &&
1429 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1436 template <
typename ContainerA,
typename ContainerB>
1437 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &b) {
1438 for (
auto [elemA, elemB] : llvm::zip(a, b)) {
1439 if (elemA < 0 || elemB < 0)
1454 void updateStateForNextIteration(
Value v) {
1461 LogicalResult handleTransposeOp();
1464 LogicalResult handleInsertOpWithMatchingPos(
Value &res);
1479 LogicalResult handleInsertOpWithPrefixPos(
Value &res);
1484 Value tryToFoldExtractOpInPlace(
Value source);
1486 ExtractOp extractOp;
1488 int64_t extractedRank;
1490 InsertOp nextInsertOp;
1491 TransposeOp nextTransposeOp;
1506 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1508 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1509 extractedRank(extractOp.getNumIndices()) {
1510 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1511 sentinels.reserve(vectorRank - extractedRank);
1512 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1513 sentinels.push_back(-(i + 1));
1515 extractOp.getStaticPosition().end());
1521 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1523 if (extractOp.hasDynamicPosition())
1526 if (!nextTransposeOp)
1529 nextTransposeOp.getPermutation(), extractOp.getContext()));
1536 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1539 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1546 res = nextInsertOp.getValueToStore();
1548 return success(canFold());
1555 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(
Value &res) {
1557 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1570 res = nextInsertOp.getValueToStore();
1578 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1581 if (extractOp.hasDynamicPosition())
1585 bool nothingToFold = (source == extractOp.getSource());
1586 if (nothingToFold || !canFold())
1591 extractOp.setStaticPosition(
1593 extractOp.getSourceMutable().assign(source);
1594 return extractOp.getResult();
1598 Value ExtractFromInsertTransposeChainState::fold() {
1600 if (extractOp.hasDynamicPosition())
1603 Value valueToExtractFrom = extractOp.getSource();
1604 updateStateForNextIteration(valueToExtractFrom);
1605 while (nextInsertOp || nextTransposeOp) {
1608 if (succeeded(handleTransposeOp())) {
1609 valueToExtractFrom = nextTransposeOp.getVector();
1610 updateStateForNextIteration(valueToExtractFrom);
1616 if (succeeded(handleInsertOpWithMatchingPos(result)))
1621 if (succeeded(handleInsertOpWithPrefixPos(result)))
1622 return tryToFoldExtractOpInPlace(result);
1632 valueToExtractFrom = nextInsertOp.getDest();
1633 updateStateForNextIteration(valueToExtractFrom);
1636 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1641 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1642 auto vecType = dyn_cast<VectorType>(type);
1643 return vecType && vecType.getRank() == 0;
1653 if (isa<BroadcastOp, SplatOp>(op))
1656 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1664 VectorType srcType = shapeCast.getSourceVectorType();
1666 uint64_t srcRank = srcType.getRank();
1668 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1694 Operation *defOp = extractOp.getSource().getDefiningOp();
1701 if (extractOp.getType() == input.
getType())
1707 auto inputType = llvm::dyn_cast<VectorType>(input.
getType());
1708 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1709 unsigned inputRank = inputType ? inputType.getRank() : 0;
1710 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1711 unsigned extractRank = extractType ? extractType.getRank() : 0;
1714 if (extractRank > inputRank)
1718 assert(inputType &&
"input must be a vector type because of previous checks");
1727 extractType.getShape() != inputShape.take_back(extractRank))
1732 unsigned deltaOverall = inputRank - extractRank;
1733 unsigned deltaBroadcast = broadcastRank - inputRank;
1737 for (
auto [i, size] :
llvm::enumerate(inputShape.take_front(deltaOverall))) {
1738 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1741 extractOp->setOperands(
1742 llvm::to_vector(llvm::concat<Value>(
ValueRange(input), dynPos)));
1743 extractOp.setStaticPosition(staticPos);
1744 return extractOp.getResult();
1760 if (extractOp.hasDynamicPosition())
1763 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1768 if (shuffleOp.getResultVectorType().getRank() != 1)
1771 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1772 auto shuffleMask = shuffleOp.getMask();
1773 int64_t extractIdx = extractOp.getStaticPosition()[0];
1774 int64_t shuffleIdx = shuffleMask[extractIdx];
1777 if (shuffleIdx < inputVecSize) {
1778 extractOp.setOperand(0, shuffleOp.getV1());
1779 extractOp.setStaticPosition({shuffleIdx});
1781 extractOp.setOperand(0, shuffleOp.getV2());
1782 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1785 return extractOp.getResult();
1791 if (extractOp.hasDynamicPosition())
1794 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1799 auto getDimReverse = [](VectorType type, int64_t n) {
1800 return type.getShape().take_back(n + 1).front();
1802 int64_t destinationRank =
1803 llvm::isa<VectorType>(extractOp.getType())
1804 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1806 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1808 if (destinationRank > 0) {
1809 auto destinationType =
1810 llvm::cast<VectorType>(extractOp.getResult().getType());
1811 for (int64_t i = 0; i < destinationRank; i++) {
1815 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1816 getDimReverse(destinationType, i))
1823 std::reverse(extractedPos.begin(), extractedPos.end());
1826 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1827 strides.push_back(stride);
1829 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1832 int64_t position =
linearize(extractedPos, strides);
1836 int64_t numDimension =
1837 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1839 for (int64_t i = 0; i < numDimension; i++) {
1840 newStrides.push_back(stride);
1842 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1844 std::reverse(newStrides.begin(), newStrides.end());
1848 extractOp.setStaticPosition(newPosition);
1849 extractOp.setOperand(0, shapeCastOp.getSource());
1850 return extractOp.getResult();
1856 if (extractOp.hasDynamicPosition())
1859 auto extractStridedSliceOp =
1860 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1861 if (!extractStridedSliceOp)
1870 if (extractStridedSliceOp.hasNonUnitStrides())
1875 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1876 while (!sliceOffsets.empty()) {
1877 size_t lastOffset = sliceOffsets.size() - 1;
1878 if (sliceOffsets.back() != 0 ||
1879 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1880 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1882 sliceOffsets.pop_back();
1884 unsigned destinationRank = 0;
1885 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1886 destinationRank = vecType.getRank();
1889 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1890 sliceOffsets.size())
1894 assert(extractedPos.size() >= sliceOffsets.size());
1895 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1896 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1897 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1901 extractOp.setStaticPosition(extractedPos);
1902 return extractOp.getResult();
1908 if (extractOp.hasDynamicPosition())
1911 int64_t destinationRank =
1912 llvm::isa<VectorType>(extractOp.getType())
1913 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1915 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1925 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1926 insertOp.getSourceVectorType().getRank();
1927 if (destinationRank > insertOp.getSourceVectorType().getRank())
1929 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1932 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1933 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1936 bool disjoint =
false;
1938 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1939 int64_t start = insertOffsets[dim];
1941 (dim < insertRankDiff)
1943 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1944 int64_t end = start + size;
1945 int64_t offset = extractOffsets[dim];
1947 if (start <= offset && offset < end) {
1948 if (dim >= insertRankDiff)
1949 offsetDiffs.push_back(offset - start);
1959 int64_t srcRankDiff =
1960 insertOp.getSourceVectorType().getRank() - destinationRank;
1961 for (int64_t i = 0; i < destinationRank; i++) {
1962 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1963 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1967 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
1970 extractOp.setStaticPosition(offsetDiffs);
1971 return extractOp.getResult();
1975 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1988 if (extractOp.hasDynamicPosition())
1992 auto fromElementsOp = extractOp.getSource().
getDefiningOp<FromElementsOp>();
1993 if (!fromElementsOp)
1997 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1998 if (vecType.isScalable())
2002 int64_t rank = vecType.getRank();
2004 if (extractOp.getType() != vecType.getElementType())
2006 assert(
static_cast<int64_t
>(indices.size()) == rank &&
2007 "unexpected number of indices");
2012 for (
int i = rank - 1; i >= 0; --i) {
2013 flatIndex += indices[i] * stride;
2014 stride *= vecType.getDimSize(i);
2016 return fromElementsOp.getElements()[flatIndex];
2021 template <
typename OpType,
typename AdaptorType>
2024 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2025 OperandRange dynamicPosition = op.getDynamicPosition();
2028 if constexpr (std::is_same_v<OpType, ExtractOp>)
2029 vectorShape = op.getSourceVectorType().getShape();
2034 if (!dynamicPosition.size())
2041 bool opChange =
false;
2042 for (
unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2043 if (ShapedType::isStatic(staticPosition[i]))
2045 Attribute positionAttr = dynamicPositionAttr[index];
2046 Value position = dynamicPosition[index++];
2047 if (
auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2048 int64_t value = attr.getInt();
2052 staticPosition[i] = attr.getInt();
2057 operands.push_back(position);
2061 op.setStaticPosition(staticPosition);
2062 op.getOperation()->setOperands(operands);
2064 return op.getResult();
2073 int64_t poisonVal) {
2074 if (!is_contained(staticPos, poisonVal))
2082 if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2091 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2096 if (denseAttr.isSplat()) {
2098 if (
auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2103 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2104 if (vecTy.isScalable())
2107 if (extractOp.hasDynamicPosition()) {
2122 copy(extractOp.getStaticPosition(), completePositions.begin());
2125 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2128 if (
auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2130 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2133 newAttr = *denseValuesBegin;
2143 if (getNumIndices() == 0 && getSource().
getType() == getResult().
getType())
2154 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2160 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2175 return inplaceFolded;
2188 Operation *defOp = extractOp.getSource().getDefiningOp();
2189 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2195 BroadcastableToResult::Success)
2211 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2215 VectorType extractedMaskType =
2216 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2218 if (!extractedMaskType)
2221 auto maskOperands = createMaskOp.getOperands();
2223 VectorType maskType = createMaskOp.getVectorType();
2225 bool containsUnknownDims =
false;
2228 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2230 int64_t pos = extractOpPos[dimIdx];
2231 Value operand = maskOperands[dimIdx];
2232 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2235 containsUnknownDims =
true;
2239 int64_t createMaskBound =
2240 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2242 if (pos != ShapedType::kDynamic) {
2245 allFalse |= pos >= createMaskBound;
2246 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2250 containsUnknownDims =
true;
2257 }
else if (!containsUnknownDims) {
2259 extractOp, extractedMaskType,
2260 maskOperands.drop_front(extractOpPos.size()));
2270 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2272 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2276 VectorType sourceType = castOp.getSourceVectorType();
2277 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2281 if (sourceType.getNumElements() != targetType.getNumElements())
2285 castOp.getSource());
2295 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2298 if (extractOp.hasDynamicPosition())
2302 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2307 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2308 if (!fromElementsOp)
2310 VectorType inputType = fromElementsOp.getType();
2313 if (resultType.isScalable() || inputType.isScalable())
2319 llvm::to_vector(extractOp.getStaticPosition());
2320 firstElementPos.append(resultType.getRank(), 0);
2323 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2324 flatIndex += firstElementPos[i] * stride;
2325 stride *= inputType.getDimSize(i);
2330 extractOp, resultType,
2331 fromElementsOp.getElements().slice(flatIndex,
2332 resultType.getNumElements()));
2340 results.
add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2341 results.
add(foldExtractFromShapeCastToShapeCast);
2342 results.
add(foldExtractFromFromElements);
2347 for (
auto attr : arrayAttr)
2348 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2355 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2366 if (operands.empty())
2369 return llvm::all_of(operands, [&](
Value operand) {
2371 return currentDef == defOp;
2386 static LogicalResult
2389 auto fromElementsOp =
2390 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2391 if (!fromElementsOp)
2394 llvm::append_range(results, fromElementsOp.getElements());
2398 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2404 ToElementsOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
2405 ToElementsOp::Adaptor adaptor,
2407 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2408 Type elType = vecType.getElementType();
2409 inferredReturnTypes.append(vecType.getNumElements(), elType);
2430 OperandRange fromElemsOperands = fromElementsOp.getElements();
2431 if (fromElemsOperands.empty())
2434 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2442 Value toElementsInput = toElementsOp.getSource();
2443 if (fromElementsOp.getType() == toElementsInput.
getType() &&
2444 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2445 return toElementsInput;
2465 if (llvm::any_of(elements, [](
Attribute attr) {
2466 return !attr || isa<ub::PoisonAttrInterface>(attr);
2471 auto destVecType = fromElementsOp.getDest().getType();
2472 auto destEltType = destVecType.getElementType();
2473 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2478 auto convertedElements =
2480 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
2490 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2504 static LogicalResult
2507 if (!llvm::all_equal(fromElementsOp.getElements()))
2510 fromElementsOp, fromElementsOp.getType(),
2511 fromElementsOp.getElements().front());
2537 using OpRewritePattern::OpRewritePattern;
2539 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2543 if (fromElements.getType().getNumElements() == 1)
2554 for (
auto [insertIndex, element] :
2558 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2561 "element not from vector.extract");
2566 if (insertIndex == 0) {
2567 source = extractOp.getSource();
2568 }
else if (extractOp.getSource() != source) {
2570 "element from different vector");
2574 int64_t rank = position.size();
2575 assert(rank == source.getType().getRank() &&
2576 "scalar extract must have full rank position");
2587 if (insertIndex == 0) {
2588 const int64_t numElms = fromElements.getType().getNumElements();
2589 int64_t numSuffixElms = 1;
2590 int64_t index = rank;
2591 while (index > 0 && position[index - 1] == 0 &&
2592 numSuffixElms < numElms) {
2593 numSuffixElms *= source.getType().getDimSize(index - 1);
2596 if (numSuffixElms != numElms) {
2598 fromElements,
"elements do not form a suffix of source");
2600 expectedPosition = llvm::to_vector(position);
2601 combinedPosition = position.drop_back(rank - index);
2605 else if (expectedPosition != position) {
2607 fromElements,
"elements not in ascending order (static order)");
2609 increment(expectedPosition, source.getType().getShape());
2612 auto extracted = rewriter.
createOrFold<vector::ExtractOp>(
2613 fromElements.getLoc(), source, combinedPosition);
2616 fromElements, fromElements.getType(), extracted);
2624 for (
int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2626 if (indices[dim] < shape[dim])
2645 setResultRanges(getResult(), argRanges.front());
2648 std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2649 return llvm::to_vector<4>(getResultVectorType().
getShape());
2657 int64_t rankDiff = dstShape.size() - srcShape.size();
2658 int64_t dstDim = rankDiff;
2660 for (
auto [s1, s2] :
2661 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2663 assert(s1 == 1 &&
"expected \"dim-1\" broadcasting");
2673 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2692 Value BroadcastOp::createOrFoldBroadcastOp(
2695 assert(!dstShape.empty() &&
"unexpected empty dst shape");
2699 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
2700 if (broadcastedDims.contains(i))
2702 checkShape.push_back(dstShape[i]);
2704 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2705 "ill-formed broadcastedDims contains values not confined to "
2710 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
2714 if (!srcVectorType) {
2715 assert(checkShape.empty() &&
2716 "ill-formed createOrFoldBroadcastOp arguments");
2717 return b.
createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2720 assert(srcVectorType.getShape().equals(checkShape) &&
2721 "ill-formed createOrFoldBroadcastOp arguments");
2732 broadcastShape.reserve(dstShape.size());
2748 int64_t nextSrcShapeDim = broadcastedDims.size();
2749 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2750 if (broadcastedDims.contains(i)) {
2755 broadcastShape.push_back(dstShape[i]);
2756 permutation[i] = broadcastShape.size() - 1;
2762 permutation[i] = nextSrcShapeDim++;
2766 llvm::append_range(broadcastShape, srcVectorType.getShape());
2771 "unexpected \"dim-1\" broadcast");
2773 VectorType broadcastType =
VectorType::get(broadcastShape, elementType);
2775 vector::BroadcastableToResult::Success &&
2776 "must be broadcastable");
2780 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2781 if (permutation[i] != i)
2782 return b.
createOrFold<vector::TransposeOp>(loc, res, permutation);
2788 Type srcType, VectorType dstVectorType,
2789 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2791 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2793 return BroadcastableToResult::Success;
2795 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2797 return BroadcastableToResult::SourceTypeNotAVector;
2799 int64_t srcRank = srcVectorType.getRank();
2800 int64_t dstRank = dstVectorType.getRank();
2801 if (srcRank > dstRank)
2802 return BroadcastableToResult::SourceRankHigher;
2805 int64_t lead = dstRank - srcRank;
2806 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2809 bool foundMismatchingDims =
false;
2812 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2813 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2814 if (srcDim != 1 && srcDim != dstDim)
2815 foundMismatchingDims =
true;
2818 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2819 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2820 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2823 (srcDimScalableFlag != dstDimScalableFlag &&
2824 (srcDim != 1 || srcDimScalableFlag)))
2825 foundMismatchingDims =
true;
2827 if (foundMismatchingDims) {
2828 if (mismatchingDims !=
nullptr) {
2829 mismatchingDims->first.dim = srcDim;
2830 mismatchingDims->first.isScalable = srcDimScalableFlag;
2832 mismatchingDims->second.dim = dstDim;
2833 mismatchingDims->second.isScalable = dstDimScalableFlag;
2835 return BroadcastableToResult::DimensionMismatch;
2839 return BroadcastableToResult::Success;
2843 std::pair<VectorDim, VectorDim> mismatchingDims;
2845 getSourceType(), getResultVectorType(), &mismatchingDims);
2846 if (res == BroadcastableToResult::Success)
2848 if (res == BroadcastableToResult::SourceRankHigher)
2849 return emitOpError(
"source rank higher than destination rank");
2850 if (res == BroadcastableToResult::DimensionMismatch) {
2851 return emitOpError(
"dimension mismatch (")
2852 << (mismatchingDims.first.isScalable ?
"[" :
"")
2853 << mismatchingDims.first.dim
2854 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
2855 << (mismatchingDims.second.isScalable ?
"[" :
"")
2856 << mismatchingDims.second.dim
2857 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
2859 if (res == BroadcastableToResult::SourceTypeNotAVector)
2860 return emitOpError(
"source type is not a vector");
2861 llvm_unreachable(
"unexpected vector.broadcast op error");
2868 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
2872 VectorType srcType = srcShapeCast.getSourceVectorType();
2873 VectorType destType = broadcastOp.getResultVectorType();
2876 BroadcastableToResult::Success)
2881 srcShapeCast.getResultVectorType().getShape();
2884 unsigned numTrailingDims =
std::min(srcShape.size(), shapecastShape.size());
2885 if (!llvm::equal(srcShape.take_back(numTrailingDims),
2886 shapecastShape.take_back(numTrailingDims)))
2889 assert(all_of(srcShape.drop_back(numTrailingDims),
2890 [](int64_t E) { return E == 1; }) &&
2891 all_of(shapecastShape.drop_back(numTrailingDims),
2892 [](int64_t E) { return E == 1; }) &&
2893 "ill-formed shape_cast");
2895 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
2900 if (getSourceType() == getResultVectorType())
2905 if (!adaptor.getSource())
2907 auto vectorType = getResultVectorType();
2908 if (
auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2909 if (vectorType.getElementType() != attr.getType())
2913 if (
auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2914 if (vectorType.getElementType() != attr.getType())
2918 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2920 if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
2933 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2937 broadcastOp.getResultVectorType(),
2938 srcBroadcast.getSource());
2948 results.
add<BroadcastFolder>(context);
2956 VectorType resultType = getResultVectorType();
2957 VectorType v1Type = getV1VectorType();
2958 VectorType v2Type = getV2VectorType();
2960 int64_t resRank = resultType.getRank();
2961 int64_t v1Rank = v1Type.getRank();
2962 int64_t v2Rank = v2Type.getRank();
2963 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2964 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2965 if (!wellFormed0DCase && !wellFormedNDCase)
2966 return emitOpError(
"rank mismatch");
2969 for (int64_t r = 1; r < v1Rank; ++r) {
2970 int64_t resDim = resultType.getDimSize(r);
2971 int64_t v1Dim = v1Type.getDimSize(r);
2972 int64_t v2Dim = v2Type.getDimSize(r);
2973 if (resDim != v1Dim || v1Dim != v2Dim)
2974 return emitOpError(
"dimension mismatch");
2978 int64_t maskLength = mask.size();
2979 if (maskLength <= 0)
2980 return emitOpError(
"invalid mask length");
2981 if (maskLength != resultType.getDimSize(0))
2982 return emitOpError(
"mask length mismatch");
2984 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2985 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2988 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
2994 ShuffleOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
2995 ShuffleOp::Adaptor adaptor,
2997 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2998 auto v1Rank = v1Type.getRank();
3002 shape.reserve(v1Rank);
3003 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3006 llvm::append_range(shape, v1Type.getShape().drop_front());
3007 inferredReturnTypes.push_back(
3012 template <
typename T>
3015 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3016 return value == expected++;
3020 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3021 auto v1Type = getV1VectorType();
3022 auto v2Type = getV2VectorType();
3024 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3025 "Vector shuffle does not support scalable vectors");
3029 if (v1Type.getRank() == 0)
3033 auto mask = getMask();
3040 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3041 if (!v1Attr || !v2Attr)
3045 bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
3046 bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
3047 if (isV1Poison && isV2Poison)
3052 if (v1Type.getRank() != 1)
3061 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3064 v2Elements = to_vector(v2DenseAttr.getValues<
Attribute>());
3065 poisonElement = v2Elements[0];
3068 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3071 v1Elements = to_vector(v1DenseAttr.getValues<
Attribute>());
3072 poisonElement = v1Elements[0];
3076 int64_t v1Size = v1Type.getDimSize(0);
3077 for (int64_t maskIdx : mask) {
3080 if (maskIdx == ShuffleOp::kPoisonIndex) {
3081 indexedElm = poisonElement;
3083 if (maskIdx < v1Size)
3084 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3086 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3089 results.push_back(indexedElm);
3104 VectorType v1VectorType = shuffleOp.getV1VectorType();
3106 if (v1VectorType.getRank() > 0)
3108 if (mask.size() != 1)
3129 static Value getScalarSplatSource(
Value value) {
3136 if (
auto splat = dyn_cast<vector::SplatOp>(defOp))
3137 return splat.getInput();
3139 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3146 if (isa<VectorType>(
broadcast.getSourceType()))
3160 Value splat = getScalarSplatSource(op.getV1());
3161 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3177 VectorType resultType = op.getResultVectorType();
3178 if (resultType.isScalable())
3180 op,
"ShuffleOp can't represent a scalable interleave");
3182 if (resultType.getRank() != 1)
3184 op,
"ShuffleOp can't represent an n-D interleave");
3186 VectorType sourceType = op.getV1VectorType();
3187 if (sourceType != op.getV2VectorType() ||
3188 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3190 op,
"ShuffleOp types don't match an interleave");
3194 int64_t resultVectorSize = resultType.getNumElements();
3195 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3196 int64_t maskValueA = shuffleMask[i * 2];
3197 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3198 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3200 "ShuffleOp mask not interleaving");
3212 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3222 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3227 auto vectorTy = cast<VectorType>(dest.
getType());
3228 build(builder, result, source, dest,
3233 Value source,
Value dest, int64_t position) {
3246 posVals.reserve(position.size());
3247 llvm::transform(position, std::back_inserter(posVals),
3249 build(builder, result, source, dest, posVals);
3258 build(builder, result, source, dest, dynamicPos,
3263 if (
auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3264 if (srcTy.getRank() == 0)
3266 "expected a scalar instead of a 0-d vector as the source operand");
3269 auto destVectorType = getDestVectorType();
3270 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
3272 "expected position attribute of rank no greater than dest vector rank");
3273 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3274 if (srcVectorType &&
3275 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3276 static_cast<unsigned>(destVectorType.getRank())))
3277 return emitOpError(
"expected position attribute rank + source rank to "
3278 "match dest vector rank");
3279 if (!srcVectorType &&
3280 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
3282 "expected position attribute rank to match the dest vector rank");
3284 if (
auto attr = dyn_cast<Attribute>(pos)) {
3285 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3287 destVectorType.getDimSize(idx))) {
3288 return emitOpError(
"expected position attribute #")
3290 <<
" to be a non-negative integer smaller than the "
3292 "dest vector dimension";
3305 assert(positions.size() <= completePositions.size() &&
3306 "positions size must be less than or equal to destTy rank");
3307 copy(positions, completePositions.begin());
3322 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3323 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3324 srcVecType.getNumElements())
3327 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3340 Value splat = getScalarSplatSource(op.getValueToStore());
3341 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3369 class InsertChainFullyInitialized final :
public OpRewritePattern<InsertOp> {
3375 VectorType destTy = op.getDestVectorType();
3376 if (destTy.isScalable())
3380 if (
auto insertOp = dyn_cast<InsertOp>(user))
3381 if (insertOp.getDest() == op.getResult())
3384 InsertOp currentOp = op;
3388 if (currentOp.hasDynamicPosition())
3391 chainInsertOps.push_back(currentOp);
3392 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3395 if (currentOp && !currentOp->hasOneUse())
3399 int64_t vectorSize = destTy.getNumElements();
3400 int64_t initializedCount = 0;
3406 for (
auto insertOp : chainInsertOps) {
3408 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3412 int64_t insertBeginPosition =
3417 int64_t insertSize = 1;
3418 if (
auto srcVectorType =
3419 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3420 insertSize = srcVectorType.getNumElements();
3422 assert(insertBeginPosition + insertSize <= vectorSize &&
3423 "insert would overflow the vector");
3425 for (
auto index : llvm::seq<int64_t>(insertBeginPosition,
3426 insertBeginPosition + insertSize)) {
3427 if (initializedDestIdxs[index])
3429 initializedDestIdxs[index] =
true;
3435 pendingInsertPos.push_back(insertBeginPosition);
3436 pendingInsertSize.push_back(insertSize);
3437 pendingInsertValues.push_back(insertOp.getValueToStore());
3439 if (initializedCount == vectorSize)
3444 if (initializedCount != vectorSize)
3448 for (
auto [insertBeginPosition, insertSize, valueToStore] :
3449 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3450 pendingInsertValues))) {
3451 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3453 if (!srcVectorType) {
3454 elements[insertBeginPosition] = valueToStore;
3459 srcVectorType.getElementType());
3461 auto elementsToInsert = vector::ToElementsOp::create(
3462 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3463 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3464 elements[insertBeginPosition + linearIdx] =
3465 elementsToInsert.getResult(linearIdx);
/// Folds an insert of a constant into a constant destination into a new
/// constant (fragment).
static Attribute
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
                                  Attribute dstAttr,
                                  int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst)
    return {};
  // ...
  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  int64_t insertBeginPosition =
      calculateInsertPosition(destTy, insertOp.getStaticPosition());
  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  // Convert integer attributes to the destination element type if needed.
  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      if (auto intAttr = dyn_cast<IntegerAttr>(value))
        insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
      else
        insertedValues.push_back(value);
  } else {
    if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
      insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
    else
      insertedValues.push_back(srcAttr);
  }

  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);
  return DenseElementsAttr::get(destTy, allValues);
}
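// Illustrative fold (a sketch, hypothetical values): inserting the constant
// 9 at position [1] into the dense constant [0, 1, 2, 3] : vector<4xi32>
// computes insertBeginPosition = 1 and overwrites one element, producing the
// new dense constant [0, 9, 2, 3] with no insert op left behind.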
/// Folder that skips an intermediate insert when a later insert at the same
/// position fully overwrites it (fragment).
static Value foldInsertUseChain(InsertOp insertOp) {
  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
  if (!destInsert)
    return {};

  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
    return {};

  insertOp.setOperand(1, destInsert.getDest());
  return insertOp.getResult();
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertChainFullyInitialized>(context);
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold rank-preserving inserts with no indices (but not scalar-into-0-D,
  // which would change the type).
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  // Fold poison indices.
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  // ...
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }
  return inplaceFolded;
}
// Verifies that an integer array attribute has no more entries than the
// vector rank.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}

// Returns success if all integers in `arrayAttr` are in the interval
// [min, max). If `halfOpen` is false the interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}

// Returns success if each integer in `arrayAttr` is confined to the
// corresponding dimension of `shape`, with the same half-open convention.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}

// Returns success if, for each dimension, the sum of the corresponding
// entries of `arrayAttr1` and `arrayAttr2` is confined to `shape`.
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
// Builds an I64 ArrayAttr from a list of values (fragment).
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    // ...
  });
  // ...
}

//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  // (elided: integer-array checks on offsets/strides against destShape via
  // the helpers above, passing offName, "source vector shape", etc.)

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
/// Pattern to fold insert_strided_slice(splat-like(X), splat-like(X)) -> dst.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto dst = insertStridedSliceOp.getDest();
    auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
    if (!splat || getScalarSplatSource(dst) != splat)
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, dst);
    return success();
  }
};

/// Pattern to fold insert_strided_slice(extract_strided_slice(dst), dst)
/// -> dst, when offsets and strides match.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getValueToStore()
            .getDefiningOp<vector::ExtractStridedSliceOp>();
    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check that the slices have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
// Pattern to fold an insert_strided_slice of a constant into a constant
// destination into a single new constant.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the destination constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // ... (match constant dest/source operands)
    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    // TODO: Support poison.
    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
      return failure();

    // TODO: Handle non-unit strides.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    // ...
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    // Enumerate all positions of the slice within the destination; the slice
    // coordinates are the trailing `sliceVecTy.getRank()` entries.
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      // ...
    } while (/* more slice positions remain */);
    // ... (replace op with the new dense constant)
    return success();
  }
};

void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getValueToStore();
  return {};
}
//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc())
    p << ", " << getAcc();
  // ...
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}

ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  // ... (parse operand list and the two types)
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand (AXPY).
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  // Default the combining kind if absent.
  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  return failure(
      // ... (resolve lhs/rhs) ...
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}

LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}

// MaskableOpInterface method: the expected mask is an i1 vector with the
// result's shape and scalability.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
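// Illustrative op for the verifier above (a sketch; attributes elided):
//
//   %c = vector.outerproduct %a, %b, %acc
//       : vector<4xf32>, vector<8xf32>
//
// %a supplies result dim #1 (4) and %b result dim #2 (8), so %acc and the
// result must both be vector<4x8xf32>. With a scalar %b the op degenerates
// to an AXPY and the result stays 1-D.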
//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//

// Result-type inference: leading dims come from `sizes`, trailing dims are
// inherited from the source vector type.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}

// Builder (fragment): infers the result type and attaches the offsets,
// sizes, and strides attributes.
// ...
  result.addTypes(inferStridedSliceOpResultType(
      llvm::cast<VectorType>(source.getType()), offsetsAttr, sizesAttr,
      stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  // ...
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
// ...

LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  // (elided: integer-array checks on offsets/sizes/strides against `shape`
  // via the helpers above, using offName, sizesName, stridesName)

  auto resultType =
      inferStridedSliceOpResultType(type, offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }
  return success();
}
// Folds an extract_strided_slice that reads from a chain of
// insert_strided_slice ops: if the extracted chunk is fully contained in one
// insert's slice, read from that insert's source; if disjoint, keep walking
// up the chain.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the extract has higher rank than the insert's source, we would be
    // extracting a partial chunk of the inserted vector.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the extract offset starts inside the inserted interval.
      if (start <= offset && offset < end) {
        // Overlapping but not fully contained: a partial overlap.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk is a subset of the inserted chunk: rewrite the
    // extract to read from the insert's source with adjusted offsets.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getValueToStore());
      // ... (update the offsets attribute from offsetDiffs)
      return success();
    }
    // Disjoint chunk: keep looking up the insert chain.
    if (disjoint) {
      insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
      continue;
    }
    return failure();
  }
  return failure();
}
// Folds extract_strided_slice of a non-splat dense constant into a smaller
// dense constant (fragment).
  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
  if (!dense)
    return {};

  // TODO: Handle non-unit strides.
  if (op.hasNonUnitStrides())
    return {};

  VectorType sourceVecTy = op.getSourceVectorType();
  // ...
  VectorType sliceVecTy = op.getType();
  int64_t rank = sliceVecTy.getRank();
  // ... (expand offsets/sizes to `rank`, compute sourceStrides)
  const auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(sliceVecTy.getNumElements());
  do {
    int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
  } while (succeeded(
      incSlicePosition(currSlicePosition, sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
             sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  // ... (build the new dense constant from sliceValues)

OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getSource();
  // A strided slice of a splat is the same splat, reshaped.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return splat.resizeSplat(getType());
  // ...
}
// Pattern to rewrite extract_strided_slice(create_mask) as a smaller
// create_mask with adjusted operands.
class StridedSliceCreateMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    Location loc = extractStridedSliceOp.getLoc();
    // Bail if the source is not a CreateMaskOp.
    auto createMaskOp =
        extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
    if (!createMaskOp)
      return failure();
    // Bail on non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // ... (gather mask operands and the slice's offsets/sizes)

    // Compute the sliced mask bounds: each sliced dimension's mask size is
    // the original dynamic size minus the slice offset.
    SmallVector<Value> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      IntegerAttr offsetAttr =
          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
      Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
      Value sliceMaskDimSize =
          arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Unsliced dimensions keep their original mask operands.
    llvm::append_range(
        sliceMaskDimSizes,
        llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
    // Replace with a create_mask of the sliced region.
    rewriter.replaceOpWithNewOp<CreateMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};

// Pattern to rewrite extract_strided_slice(constant_mask) as a smaller
// constant_mask.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Bail if the source is not a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Bail on non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // ... (gather constant mask sizes and the slice's offsets/sizes)

    // Compute the sliced constant mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Unsliced dimensions keep their original sizes.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any dimension is zero, the whole mask is all-false (the region is a
    // conjunction of the per-dimension intervals).
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
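// Worked example for the constant-mask arithmetic above: with
// maskDimSizes = [3], sliceOffset = 1 and sliceSize = 2, the new size is
// max(0, min(1 + 2, 3) - 1) = 2, so slicing a vector.constant_mask [3]
// yields vector.constant_mask [2].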
// Pattern to rewrite extract_strided_slice(broadcast) as
// broadcast(extract_strided_slice), slicing only where needed.
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // A slice is only needed for source dims that are neither unit (they get
    // broadcast) nor already of the destination size.
    bool needsSlice = false;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != 1 &&
          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        needsSlice = true;
        break;
      }
    }
    Value source = broadcast.getSource();
    if (needsSlice) {
      // Unit source dims are kept whole; the broadcast handles them.
      for (unsigned i = 0; i < srcRank; i++) {
        if (srcVecType.getDimSize(i) == 1) {
          // ... (use offset 0 / size 1 for this dim)
        }
        // ... (otherwise take the slice's offset and size for this dim)
      }
      source = ExtractStridedSliceOp::create(rewriter, op->getLoc(), source,
                                             offsets, sizes, strides);
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};

/// Pattern to rewrite extract_strided_slice(splat-like(X)) as broadcast(X).
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getSource());
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};

/// Pattern to rewrite a contiguous extract_strided_slice (unit outer sizes,
/// full inner sizes) as a vector.extract.
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Number of leading offsets to pass to the extract: the trailing dims
    // whose size matches the source are taken whole.
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // No offsets left: this is the identity and is handled elsewhere.
    if (numOffsets == 0)
      return failure();

    // If not even the innermost dimension is full-size, this cannot be an
    // extract.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The leading (sliced) dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating results with leading unit dims.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
                                              extractOffsets);
    // ... (replace op with `extract`, shape-casting if needed)
    return success();
  }
};

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
              StridedSliceBroadcast, StridedSliceSplat,
              ContiguousExtractStridedSliceToExtract>(context);
}
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

/// Builder that sets the padding to a poison value if not provided and uses
/// an empty mask (variant taking attributes).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}

/// Builder variant taking an AffineMap and optional in_bounds values.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  // ...
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : /* default: all dims out-of-bounds */ ...;
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, *padding,
        permutationMapAttr, inBoundsAttr);
}

/// Builder that additionally infers a minor-identity permutation map from the
/// shaped and vector types.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : /* default: all dims out-of-bounds */ ...;
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // The permutation map must have exactly `rankOffset` results.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // The permutation map must have as many results as the vector has dims.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
// Printing helper: elides attributes that are at their default values.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds attribute if all dims are out-of-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  // ...
}

// TransferReadOp::print (fragment):
  if (getMask())
    p << ", " << getMask();
// ...

// Mask type inference (fragment): the mask shape is the vector shape pulled
// back through the inverse of the permutation map.
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
  // 0-D masks are not supported; use a [1] mask instead.
  if (maskShape.empty())
    maskShape.push_back(1);
// ...
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  // ... (parse base, indices, padding)
  if (hasMask.succeeded()) {
    if (parser.parseComma() || parser.parseOperand(maskInfo))
      return failure();
  }
  // ... (parse attribute dict and the two types)
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  // Default to a minor-identity permutation map when none is given.
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    // ...
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  // Default in_bounds to all-false when absent.
  if (!inBoundsAttr) {
    // ...
  }
  // ... (resolve base, index, and padding operands)
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults())
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    // The mask type is computed from the vector type and permutation map
    // rather than parsed, to keep the assembly format compact.
    // ...
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  // ...
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type: padding must have exactly that type.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

// MaskableOpInterface method: the expected mask type is inferred from the
// vector type and the permutation map.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

VectorType TransferReadOp::getVectorType() {
  return cast<VectorType>(getVector().getType());
}
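// Illustrative op for the verifier above (a sketch; names are hypothetical):
//
//   %v = vector.transfer_read %t[%i, %j], %pad
//       {permutation_map = affine_map<(d0, d1) -> (d1)>}
//       : tensor<?x?xf32>, vector<16xf32>
//
// The index count matches the source rank (2), the permutation map has one
// result matching the vector rank, and %pad has the tensor's element type.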
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // A dim is provably in-bounds only for static source dims and constant
  // indices.
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}

template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Indices of non-broadcast dims, used when analysing broadcast dims below.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds: keep as is.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds: check whether it is statically in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }
    newInBounds.push_back(inBounds);
    changed |= inBounds;
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are in-bounds, the
  // broadcast dims are too.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // ... (write the new in_bounds attribute)
  return success();
}

template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();
  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();
  op.getMaskMutable().clear();
  return success();
}
/// Fold a read-after-write on a tensor: a transfer_read of the value just
/// written by a matching transfer_write yields the written vector directly.
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  // ... (fold in_bounds refinement, full masks, and source casts)
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    // (records a Read effect on the base memref)
    // ...
}

// Reads with pure tensor semantics are speculatable.
if (hasPureTensorSemantics())
  // ...
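// Illustrative read-after-write fold (a sketch; assumes matching indices,
// types, and permutation maps as required by checkSameValueRAW):
//
//   %w = vector.transfer_write %v, %t[%c0] : vector<4xf32>, tensor<4xf32>
//   %r = vector.transfer_read %w[%c0], %pad : tensor<4xf32>, vector<4xf32>
//
// folds %r to %v. Intervening writes are skipped only when
// isDisjointTransferIndices proves they cannot overlap the read.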
/// Rewrite a transfer_read of a tensor produced by a transfer_write of a
/// lower-rank vector as a broadcast (plus transpose) of the written vector.
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if alias analysis would be needed.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail if bounds analysis would be needed.
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    // The accessed chunks must match exactly.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // ...
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();

    Value vec = defWrite.getVector();
    // ... (compute the permutation between the write and the read)
    for (const auto &pos : llvm::enumerate(permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType(),
        broadcastScalableFlags);
    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
    // ... (transpose `vec` back into the read's dimension order)
    return success();
  }
};

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//

/// 1. Builder with type inference: the result is the dest's tensor type, or
/// none for memrefs.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// 2. Builder with type inference that sets an empty mask.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 3. Builder taking an AffineMap and optional in_bounds values.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder that infers a minor-identity permutation map.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  // ... (parse vector, base, indices)
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  // ... (parse attribute dict and the two types)
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  // Default to a minor-identity permutation map when none is given.
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    // ...
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  // Default in_bounds to all-false when absent.
  if (!inBoundsAttr) {
    // ...
  }
  // ... (resolve operands)
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults())
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    // ...
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  // ... (print vector, base, indices)
  if (getMask())
    p << ", " << getMask();
  // ...
}

LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // Broadcast dimensions are not allowed on writes, as their semantics would
  // be unclear.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

// MaskableOpInterface method: inferred like TransferReadOp's.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

Value TransferWriteOp::getVector() { return getOperand(0); }
VectorType TransferWriteOp::getVectorType() {
  return cast<VectorType>(getValueToStore().getType());
}
/// Fold a write of a freshly read full tensor back into an identical tensor:
/// the result is simply the originally read tensor.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match (the whole tensor is transferred).
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // All indices must be zero.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getBase());
  return success();
}

static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getBase() == write.getBase() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}

/// Fold a write-after-read on the same tensor: writing back the value just
/// read leaves the tensor unchanged.
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getBase());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  // ... (fold in_bounds refinement, full masks, and source casts)
}
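// Illustrative write-after-read (WAR) fold handled by foldWAR (a sketch):
//
//   %v = vector.transfer_read %t0[%c0], %pad : tensor<4xf32>, vector<4xf32>
//   %t1 = vector.transfer_write %v, %t0[%c0] : vector<4xf32>, tensor<4xf32>
//
// Writing back exactly the value just read leaves the tensor unchanged, so
// %t1 folds to %t0.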
std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    // (records a Write effect on the base memref)
    // ...
}

// Writes with pure tensor semantics are speculatable.
if (hasPureTensorSemantics())
  // ...

/// Pattern to remove a dead (fully shadowed) transfer_write in a
/// write-after-write chain, so it can be eliminated by DCE.
struct FoldWaw final : public OpRewritePattern<TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        writeToModify.getBaseMutable().assign(defWrite.getBase());
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // Only look through writes with a single use.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
/// Swap tensor.extract_slice(vector.transfer_write) so the write happens
/// after the slice is taken, when the write covers the whole extracted
/// tensor.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail on rank-reducing chains.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if the TransferWriteOp has non-zero indices.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TransferWriteOp has non-zero offset");
    }

    // Fail if the InsertSliceOp and ExtractSliceOp ranks differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the write may not cover the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the extract_slice in front of the transfer_write.
    auto newExtractOp = tensor::ExtractSliceOp::create(
        rewriter, extractOp.getLoc(), insertOp.getSourceType(),
        insertOp.getDest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    auto newTransferWriteOp = TransferWriteOp::create(
        rewriter, transferOp.getLoc(), transferOp.getVector(),
        newExtractOp.getResult(), transferOp.getIndices(),
        transferOp.getPermutationMapAttr(),
        /* ... new in_bounds attribute ... */);
    insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    return success();
  }
};

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
//===----------------------------------------------------------------------===//
// LoadOp / StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If the rank is 0 or the vector has a single element, the memref layout
  // does not matter: the access is always valid.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}

LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < resVecTy.getRank())
    return emitOpError(
        "destination memref has lower rank than the result vector");

  // Checks for memrefs of vectors.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < valueVecTy.getRank())
    return emitOpError(
        "source memref has lower rank than the vector to store");

  // Checks for memrefs of vectors.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  // ... (fold away memref casts on the base)
}

std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
//===----------------------------------------------------------------------===//
// MaskedLoadOp / MaskedStoreOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
// All-true masks fold to a plain load; all-false masks fold to the pass-thru
// value.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
} // namespace

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}

LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}

namespace {
// All-true masks fold to a plain store; all-false masks erase the op.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
} // namespace

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  // ... (fold away memref casts on the base)
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

// MaskableOpInterface method: the expected mask is an i1 vector with the
// index vector's shape and scalability.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// Check whether `indexVec` is a constant 1-D vector of consecutive values
/// [0, 1, 2, ...].
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();
  // ...
  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}

namespace {
// All-false masks fold the gather to the pass-thru value; an all-true gather
// has no unmasked equivalent to fold to.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

/// Fold gathers with consecutive offsets [0, 1, 2, ...] into a contiguous
/// masked load.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getOffsets(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};
} // namespace

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}
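// Illustrative gather canonicalization (a sketch; names are hypothetical): a
// gather whose index vector is the constant sequence [0, 1, ..., n-1] reads
// a contiguous chunk, so
//
//   %g = vector.gather %base[%c0][%seq], %mask, %pass
//       : memref<16xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
//         into vector<4xf32>
//
// is replaced by the equivalent vector.maskedload of %base at the same
// offsets.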
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getOffsets()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
// All-false masks erase the scatter; an all-true scatter has no unmasked
// equivalent to fold to.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    // ... (switch on getMaskFormat(scatter.getMask()), as in the folders
    // above)
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

/// Fold scatters with consecutive offsets [0, 1, 2, ...] into a contiguous
/// masked store.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
    return success();
  }
};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}
//===----------------------------------------------------------------------===//
// ExpandLoadOp / CompressStoreOp
//===----------------------------------------------------------------------===//

LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
// All-true masks fold to a plain load; all-false masks fold to the pass-thru
// value.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}

LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
// All-true masks fold to a plain store; all-false masks erase the op.
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//

void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

LogicalResult ShapeCastOp::verify() {
  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that the element type is preserved.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitOpError("has different source and result element types");

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}
/// Return true if `transpose` does not permute any pair of non-unit dims, in
/// which case the flattened input and output are numerically identical and
/// the transpose is effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current)
        return false;
      current = p;
    }
  }
  return true;
}

OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x), if the transpose preserves
  // element order.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // shape_cast(broadcast(x)) -> x, if x already has the result type.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> constant.
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  // shape_cast(poison) -> poison.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());

  return {};
}
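// Illustrative fold (a sketch): in
//
//   %t = vector.transpose %v, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
//   %s = vector.shape_cast %t : vector<4x1xf32> to vector<4xf32>
//
// the transpose only moves a unit dim, so isOrderPreserving holds and the
// shape_cast can consume %v directly.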
namespace {

/// Return the type obtained by trimming trailing non-scalable unit dims,
/// keeping at least one dim.
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;
  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least one dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Folds a shape_cast that merely drops trailing unit dims of a
/// create_mask/constant_mask into a smaller mask op, provided each dropped
/// dim's mask bound is 1.
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Every dropped trailing dim's operand must be the constant 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      auto newMaskOperands = maskOperands.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Every dropped trailing dim's mask size must be 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
          shapeOp, shapeOpResTy, newMaskOperands);
      return success();
    }

    return failure();
  }
};
/// Pattern to rewrite Y = shape_cast(broadcast(X)) as either
/// i) Y = shape_cast(X), or ii) Y = broadcast(X).
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;

    // i) Y = shape_cast(X), if X and Y have the same number of elements.
    if (srcVectorType) {
      if (srcVectorType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    // ii) Y = broadcast(X), if X (a scalar is always broadcastable) can be
    // broadcast directly to Y's type.
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    if (srcIsScalar ||
        vector::isBroadcastableTo(broadcastOp.getSourceType(), dstVectorType) ==
            BroadcastableToResult::Success) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, dstVectorType, broadcastOp.getSource());
      return success();
    }
    return failure();
  }
};

} // namespace

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results
      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
          context);
}
//===----------------------------------------------------------------------===//
// BitCastOp
//===----------------------------------------------------------------------===//

LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}

OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32: pack two copies of the 16-bit pattern.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern into both halves.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        // ... (build and return the f32 splat from floatBits)
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        // ... (compute srcBitWidth/dstBitWidth)
        // Casting to a wider integer: replicate the narrow pattern.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower-width element across the wide value.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          // ... (build and return the integer splat from intBits)
        }
      }
    }
  }

  return {};
}
//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

// Concatenates the memref shape with the shape of its vector element type,
// if any.
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
                              memRefType.getShape().end());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memref type with a single vector, e.g.
/// memref<4x5xvector<6xf32>> -> memref<vector<4x5x6xf32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}

LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  // ... (attach the permutation attribute)
}

OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity transposes, and more generally any transpose that
  // preserves the order of the elements.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return {};
}

LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify the transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
namespace {

// Rewrites two back-to-back TransposeOp operations into a single one.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Bail if the input is not defined by another transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};

/// Replace transpose(splat-like(v)) with broadcast(v).
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(transposeOp.getVector());
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transposeOp, transposeOp.getResultVectorType(), splat);
    return success();
  }
};
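// Worked example of composePermutations above: if the parent transpose uses
// permutation1 = [1, 2, 0] and the outer one permutation2 = [2, 0, 1], then
// result[i] = permutation1[permutation2[i]] gives [0, 1, 2], i.e. the pair
// composes to the identity and the two transposes cancel.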
/// Folds transpose(create_mask)/transpose(constant_mask) into a new mask op
/// with permuted operands.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // ... (apply the transpose permutation to the mask bounds)
    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      // ... (permute maskOperands into newOperands)
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    // ... (permute maskDimSizes into newMaskDimSizes)
    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};

/// Folds transpose(shape_cast) into a new shape_cast when the transpose only
/// permutes unit dims (i.e. is order preserving).
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};
/// Folds transpose(broadcast(x)) to broadcast(x) when the permutation is
/// local to the dims introduced by the broadcast.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    // The permutation must stay within each group of dims delimited by
    // non-unit input dims; otherwise the fold would reorder real data.
    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeBroadcast>(context);
}
//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that the array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each entry is in bounds of the corresponding result dim.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // If one dim size is zero they must all be zero, because the mask region is
  // the conjunction of the per-dimension intervals.
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}

OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> bounds = getMaskDimSizes();
  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();

  auto createBoolSplat = [&](bool x) {
    return SplatElementsAttr::get(getVectorType(),
                                  BoolAttr::get(getContext(), x));
  };

  // Check the corner case of 0-D vectors first.
  if (vectorSizes.empty()) {
    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
    return createBoolSplat(bounds[0] == 1);
  }
  // Fold all-true and all-false masks to splat constants.
  if (bounds == vectorSizes)
    return createBoolSplat(true);
  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
    return createBoolSplat(false);
  return OpFoldResult();
}
6859 build(builder, result, type, operands);
6863 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
6865 if (vectorType.getRank() == 0) {
6866 if (getNumOperands() != 1)
6868 "must specify exactly one operand for 0-D create_mask");
6869 }
else if (getNumOperands() !=
6870 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
6872 "must specify an operand for each result vector dimension");
    // (Body of the CreateMaskFolder canonicalization pattern.)
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank-zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Gather the constant mask dimension sizes; bail out if an operand is
    // neither a constant nor a suitable constant multiple of vector.vscale.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // For scalable dimensions, only an all-false (zero) constant folds.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // A vscale multiple must cover the whole scalable dimension.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp the values to the constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If any dimension is zero, the whole mask is empty: fold to all zeros.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);
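// Illustrative example (not part of this listing): with constant operands the
// pattern above folds
//
//   %c2 = arith.constant 2 : index
//   %m  = vector.create_mask %c2, %c2 : vector<4x3xi1>
//
// into the equivalent
//
//   %m = vector.constant_mask [2, 2] : vector<4x3xi1>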
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");
  // ... (add the mask operand and create the mask region), then populate the
  // region via the callback:
  maskRegionBuilder(builder, maskableOp);
}

// The overloads taking result types and an optional passthru delegate here:
//   build(builder, result, resultTypes, mask, Value(), maskableOp,
//         maskRegionBuilder);
//   build(builder, result, mask, maskableOp, maskRegionBuilder);
// (Fragments of MaskOp::parse.)
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();
  // ...
  MaskOp::ensureTerminator(maskRegion, builder, result.location);
  // ...
  result.types.append(resultTypes);
  // ...
  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");
    // ...
  }
void MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print the single masked operation, skipping the implicit terminator.
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  // ...
  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // An empty mask region gets the default terminator.
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // ... a block that already ends in an explicit yield needs no more work:
  if (isa<vector::YieldOp>(block.back()))
    return;
  // ... otherwise either fall back to the default terminator or, for a single
  // masked operation, yield that operation's results:
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // ...
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // An empty vector.mask needs no further checks.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();

  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  // A result-less vector.mask can simply be removed.
  if (terminator.getNumOperands() == 0)
    return success();

  // Otherwise, fold to the values yielded by the terminator.
  llvm::append_range(results, terminator.getOperands());
  return success();
}

LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();
  // ... (a statically all-true mask lets the masked operation escape):
  Operation *maskableOp = getMaskableOp();
  // ...
  llvm::append_range(results, maskableOp->getResults());
  return success();
}
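// Illustrative fold (not part of this listing): when the mask is statically
// all true, e.g. produced by vector.constant_mask [8] : vector<8xi1>, the
// masked operation is hoisted out of the region and its results replace the
// vector.mask results, making the mask a no-op.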
    // (Body of the CanonializeEmptyMaskOp pattern.)
    if (!maskOp.isEmpty())
      return failure();

    if (!maskOp.hasPassthru())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    // An empty mask with a passthru is a select between the yielded value and
    // the passthru.
    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
        terminator.getOperand(0), maskOp.getPassthru());
    return success();
  }
};

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
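// Illustrative rewrite (not part of this listing): an empty vector.mask with
// a passthru reduces to a per-lane select between the yielded value and the
// passthru:
//
//   %0 = arith.select %mask, %yielded, %passthru
//        : vector<8xi1>, vector<8xf32>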
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check that the reduction dimension is smaller than the source rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) == rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check that the initial value shape matches the source shape with the
  // reduction dimension removed.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify that the combining kind is supported for the element type.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};
  // A splat of a constant scalar folds to a splat elements attribute.
  return SplatElementsAttr::get(getType(), {constOperand});
}

// (SplatToBroadcastPattern rewrites a vector.splat as an equivalent
// vector.broadcast of splatOp.getOperand().)
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SplatToBroadcastPattern>(context);
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRanges) {
  auto resultType = cast<VectorType>(getType());
  // The number of lanes of a scalable vector is unknown, so no bound can be
  // inferred.
  if (resultType.isScalable()) {
    return;
  }
  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
  APInt zero(bitwidth, 0);
  APInt high(bitwidth, resultType.getDimSize(0) - 1);
  ConstantIntRanges result = {zero, high, zero, high};
  setResultRanges(getResult(), result);
}
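// Illustrative range (not part of this listing): for
// vector.step : vector<4xindex>, the inferred per-element range is [0, 3],
// both signed and unsigned.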
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  // ... (move the maskable op into the region and add the terminator) ...
}

Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return MaskOp::create(builder, maskableOp->getLoc(),
                          maskableOp->getResultTypes(), mask, passthru,
                          maskableOp, createMaskOpRegion);
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, maskableOp,
                        createMaskOpRegion);
}

Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;
  return arith::SelectOp::create(builder, newValue.getLoc(),
                                 newValue.getType(), mask, newValue, passthru);
}
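// Usage sketch (not part of this listing): selectPassthru is a no-op without
// a mask; with one, it emits a per-lane select such as
//
//   %r = arith.select %mask, %newValue, %passthru
//        : vector<8xi1>, vector<8xf32>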
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"