#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
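
// Inspect a constant dense mask: count up for bits that are set, count down
// for bits that are cleared, and bail out when a mix of set and cleared bits
// is detected.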
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
      if (maskIdx > 0)
        allFalse = false;
    }
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  vector::YieldOp::create(builder, loc);
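
// Returns true when the combining kind is supported on the given element
// type: ADD and MUL accept any int/index/float element, the bitwise and
// integer min/max kinds require an int/index element, and the floating-point
// min/max kinds require a float element.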
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
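
// Compute the "effective" vector rank for a transfer op: when the element
// type of the shaped type is itself a vector (e.g. memref<8xvector<4xf32>>),
// those trailing element-vector dimensions are not transferred and are
// subtracted from the transfer vector's rank.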
                                  VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
}
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*dimCount=*/0, /*symbolCount=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank(),
      shapedType.getContext());
}
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // A missing write mask, or a write mask equal to the read mask, means the
  // write could have stored the same splat value everywhere the read sees it.
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for a constant splat as the source of the write.
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat())
    return false;
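
// A transfer_read observes exactly the value stored by an earlier
// transfer_write (read-after-write) when both agree on indices, vector type
// and permutation map, the write is in-bounds, and either both ops are
// unmasked or the masked read reproduces the written splat value.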
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
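
// Two transfers on the same base are provably disjoint when either an index
// along a leading (non-vectorized) dimension differs, or the start indices
// along a vectorized dimension are at least one full vector length apart.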
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices differ we
      // know the accesses are disjoint.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to fully compose and simplify the affine expression as a
        // fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For trailing dimensions, each transfer accesses a vectorDim-sized
      // interval; the accesses are disjoint if the intervals do not overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to fully compose and simplify the affine expression as a
        // fast track.
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return;

    // If the position exceeds the dimension size, wrap around.
    posInDim = offsetInDim;
  }
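
// Helpers converting between index Values, OpFoldResults and raw int64_t
// positions; each asserts that every entry is a compile-time constant.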
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) -> Value {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return arith::ConstantIndexOp::create(
                                 builder, loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();
                    return cast<Value>(foldResult);
                  });
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto intType = dyn_cast<IntegerType>(expectedType)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
    return attr;
  }

  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    auto intType = dyn_cast<IntegerType>(expectedType);
    if (!intType)
      return attr;
    // Reinterpret the float's bits as an integer of the same width.
    APFloat floatVal = floatAttr.getValue();
    APInt intVal = floatVal.bitcastToAPInt();
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (isa<ub::PoisonAttrInterface>(value))
    return ub::PoisonOp::create(builder, loc, type,
                                cast<ub::PoisonAttrInterface>(value));

  return arith::ConstantOp::materialize(builder, value, type, loc);
}
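
//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//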
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}
std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  // A fully reduced vector yields a scalar result.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
// same. If not all dimensions which are reduced are of unit dimension, this
// transformation does nothing. This is just a generalization of
// ElideSingleElementReduction for ReductionOp.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      }
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // We are reducing all the dimensions, and all reduced dimensions are of
      // size 1, so a simple extraction does the job.
      if (mask)
        mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
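
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//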
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind,
        vector, acc, fastMathFlags);
}
LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
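
// Maps an arith.atomic_rmw reduction kind to an equivalent 1-D
// vector.reduction op with the matching combining kind.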
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::OR, vector);
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
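
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//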
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes,
                                  CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;

  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }

  if (!result.attributes.get(getKindAttrName(result.name)))
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));

  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  std::array<VectorType, 2> maskTypes = {
void ContractionOp::print(OpAsmPrinter &p) {
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation for
      // printing.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) !=
            rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer the expected result vector type: the lhs/rhs maps together with
    // the lhs/rhs vector types fully determine it.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the (constant) extentsMap.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs and outputs. This
    // also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
}
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the indexing maps, extract the size of each iteration dimension
  // from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhs shape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}
std::vector<std::pair<int64_t, int64_t>>
ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
/// Canonicalizes `add(contract(a, b, zero), c)` into `contract(a, b, c)` by
/// folding the addend into the contraction's accumulator.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}
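
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//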
void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source) {
  auto vectorTy = cast<VectorType>(source.getType());
  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}
bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}
LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the result type");

  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
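
// Chains of extracts fold by concatenating their positions. For example:
//   %1 = vector.extract %0[0] : vector<4x8xf32> from vector<2x4x8xf32>
//   %2 = vector.extract %1[1] : vector<8xf32> from vector<4x8xf32>
// folds in place to:
//   %2 = vector.extract %0[0, 1] : vector<8xf32> from vector<2x4x8xf32>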
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getSource());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}
namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walks back the insert/transpose chain while tracking the extraction
/// position, and folds when an insertion matches that position.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over the insert/transpose chain; returns the folded value on
  /// success, null otherwise.
  Value fold();

private:
  /// Return true if the position `a` is a prefix of (contained within) `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if positions `a` and `b` intersect; negative (sentinel)
  /// entries intersect everything along their dimension.
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in
  /// the result vector.
  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Updates the state for the next iteration of the walk.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<InsertOp>();
    nextTransposeOp = v.getDefiningOp<TransposeOp>();
  }

  // Case 1. If we hit a transpose, just compose the map and iterate.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches the extract position exactly.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of the extract position,
  /// extract a portion of the source of the insertion.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition); returns null
  /// if folding is not possible.
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values (-1, -2, ...) appended to `extractPosition` to encode
  /// the not-yet-extracted trailing dimensions.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
  llvm::append_range(extractPosition, sentinels);
}
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
  return success();
}
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  // Case 2.a. early-exit fold.
  res = nextInsertOp.getValueToStore();
  // Case 2.b. if an internal transposition is present, canFold is false.
  return success(canFold());
}
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // Drop the leading dims covered by the insertion and continue from its
  // stored value.
  extractPosition.erase(extractPosition.begin(),
                        extractPosition.begin() + insertedPos.size());
  extractedRank = extractPosition.size() - sentinels.size();
  res = nextInsertOp.getValueToStore();
  return success();
}
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  // If nothing changed, or an internal transposition prevents folding, bail.
  bool nothingToFold = (source == extractOp.getSource());
  if (nothingToFold || !canFold())
    return Value();

  // Otherwise, fold in place.
  extractOp.setStaticPosition(
      ArrayRef(extractPosition).take_front(extractedRank));
  extractOp.getSourceMutable().assign(source);
  return extractOp.getResult();
}
Value ExtractFromInsertTransposeChainState::fold() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getSource();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, compose the map and iterate.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: the inserted position is a prefix of the extract position.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: the positions intersect on non-sentinel values; bail.
    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
      return Value();

    // Case 5: no intersection, take the dest of the insert and iterate.
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
/// Returns true if the operation has a 0-D vector operand or result type.
static bool hasZeroDimVectors(Operation *op) {
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;
  };

  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}

/// BroadcastOps, as well as ShapeCastOps that only prepend unit dimensions,
/// are considered "broadcast-like".
static bool isBroadcastLike(Operation *op) {
  if (isa<BroadcastOp>(op))
    return true;

  auto shapeCast = dyn_cast<ShapeCastOp>(op);
  if (!shapeCast)
    return false;

  VectorType srcType = shapeCast.getSourceVectorType();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  uint64_t srcRank = srcType.getRank();
  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
}
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.getSource().getDefiningOp();
  if (!defOp || !isBroadcastLike(defOp))
    return Value();

  Value input = defOp->getOperand(0);
  if (extractOp.getType() == input.getType())
    return input;

  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;

  // Cannot extract a larger rank than the broadcast source.
  if (extractRank > inputRank)
    return Value();

  assert(inputType && "input must be a vector type because of previous checks");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  if (extractType &&
      extractType.getShape() != inputShape.take_back(extractRank))
    return Value();

  // Replace extract(broadcast(X)) with extract(X) at an adjusted position:
  // positions along dims that were broadcast from size-1 are rewritten to 0.
  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;
  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
  SmallVector<OpFoldResult> newPositions(deltaOverall);
  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
  for (auto [i, size] :
       llvm::enumerate(inputShape.take_front(deltaOverall))) {
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
  }
  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
}
static Value foldExtractFromShuffle(ExtractOp extractOp) {
  // Dynamic positions are not folded.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // Only 1-D shuffles are handled.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Select the source vector based on the shuffle index.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
}
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp =
      extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from the lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shape_cast source to allow the fold.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Linearize the position using the strides of the extract source.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then delinearize in the coordinate system of the shape_cast source.
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  SmallVector<int64_t, 4> newStrides;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
}
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extract_strided_slice op.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
}
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the extract offset lies in the inserted interval.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the vector inserted.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap and cannot fold.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getSourceMutable().assign(insertOp.getValueToStore());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // If the extracted chunk is disjoint from the chunk inserted, keep
    // looking in the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  // Look for extract(from_elements).
  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute the flattened/linearized index and fold to the operand.
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
}
/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
/// fold them into the static position in place.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
  ArrayRef<int64_t> vectorShape;
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();
  else
    vectorShape = op.getDestVectorType().getShape();

  // If there are no dynamic operands, there is nothing to fold.
  if (!dynamicPosition.size())
    return {};

  // `index` is used to iterate over the `dynamicPosition`.
  unsigned index = 0;

  // `opChange` tracks whether `op` should be updated in place.
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (ShapedType::isStatic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
      // Do not fold if the value is out of range (-1 means poison).
      if (value >= -1 && value < vectorShape[i]) {
        staticPosition[i] = attr.getInt();
        opChange = true;
        continue;
      }
    }
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    // Return the original result to indicate an in-place fold happened.
    return op.getResult();
  }
  return {};
}
/// Fold an insert or extract operation into a poison value when a poison
/// index is found at any dimension of the static position.
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
                                                ArrayRef<int64_t> staticPos,
                                                int64_t poisonVal) {
  if (!is_contained(staticPos, poisonVal))
    return {};
  return ub::PoisonAttr::get(context);
}

/// Fold a vector extract whose source is a poison value.
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
    return srcAttr;
  return {};
}
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
                                                   Attribute srcAttr) {
  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
  if (!denseAttr)
    return {};

  if (denseAttr.isSplat()) {
    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    return newAttr;
  }

  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())
    return {};

  if (extractOp.hasDynamicPosition()) {
    return {};
  }

  // Calculate the linearized position of the continuous chunk of elements to
  // extract.
  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
  copy(extractOp.getStaticPosition(), completePositions.begin());
  int64_t startPos =
      linearize(completePositions, computeStrides(vecTy.getShape()));
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  TypedAttr newAttr;
  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
    SmallVector<Attribute> elementValues(
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
  } else {
    newAttr = *denseValuesBegin;
  }

  return newAttr;
}
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
    return getSource();
  if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
    return res;
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
    return res;
  // Fold `arith.constant` indices into the `vector.extract` operation.
  SmallVector<Value> operands = {getSource()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  if (auto val = foldExtractFromBroadcast(*this))
    return val;
  if (auto val = foldExtractFromShuffle(*this))
    return val;
  if (auto val = foldExtractFromShapeCast(*this))
    return val;
  if (auto val = foldExtractFromExtractStrided(*this))
    return val;
  if (auto val = foldExtractStridedOpFromInsertChain(*this))
    return val;
  if (auto val = foldScalarExtractFromFromElements(*this))
    return val;

  return inplaceFolded;
}
    Operation *defOp = extractOp.getSource().getDefiningOp();
    if (!defOp || !isBroadcastLike(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
    if (!outType)
      return failure();
    if (isBroadcastableTo(source.getType(), outType) !=
        BroadcastableToResult::Success)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
    return success();
    auto createMaskOp =
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = false;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim are unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the bound, the extracted mask is
        // all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dim is not all-true, and since this is a dynamic index we
        // don't know if the extraction is within the true or false region.
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();

  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();

  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
  return success();
}
/// Canonicalize the extraction of a subvector from a vector defined by
/// vector.from_elements into a smaller vector.from_elements.
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of the first extracted element and linearize it.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
  int flatIndex = 0;
  int stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  // Replace the op with a smaller from_elements op.
  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}
std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
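
//===----------------------------------------------------------------------===//
// ToElementsOp
//===----------------------------------------------------------------------===//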
/// Returns true when all of `operands` are defined by `defOp`.
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
  if (operands.empty())
    return false;

  return llvm::all_of(operands, [&](Value operand) {
    Operation *currentDef = operand.getDefiningOp();
    return currentDef == defOp;
  });
}
/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
/// (%e0, %e1, ...).
static LogicalResult
foldToElementsFromElements(ToElementsOp toElementsOp,
                           SmallVectorImpl<OpFoldResult> &results) {
  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();

  llvm::append_range(results, fromElementsOp.getElements());
  return success();
}
/// Folds vector.to_elements(vector.broadcast(%scalar)) into a replicated
/// list of the scalar.
static LogicalResult
foldToElementsOfBroadcast(ToElementsOp toElementsOp,
                          SmallVectorImpl<OpFoldResult> &results) {
  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
  if (!bcastOp)
    return failure();
  // Only handle scalar sources here; vector sources are handled by the
  // canonicalization pattern below.
  if (isa<VectorType>(bcastOp.getSource().getType()))
    return failure();

  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
  Value scalar = bcastOp.getSource();
  results.assign(resultVecType.getNumElements(), scalar);
  return success();
}
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldToElementsFromElements(*this, results)))
    return success();
  return foldToElementsOfBroadcast(*this, results);
}
LogicalResult
ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                               ToElementsOp::Adaptor adaptor,
                               SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vecType = cast<VectorType>(adaptor.getSource().getType());
  Type elType = vecType.getElementType();
  inferredReturnTypes.append(vecType.getNumElements(), elType);
  return success();
}
    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
    if (!bcastOp)
      return failure();

    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
    if (!srcType)
      return failure();

    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());

    ArrayRef<int64_t> dstShape = dstType.getShape();
    ArrayRef<int64_t> srcShape = srcType.getShape();
    int64_t dstRank = dstShape.size();
    int64_t srcRank = srcShape.size();

    // Take the elements of the broadcast source and remap each destination
    // element to the source element it was broadcast from.
    auto srcElems = vector::ToElementsOp::create(
        rewriter, toElementsOp.getLoc(), bcastOp.getSource());

    int64_t dstCount = llvm::product_of(dstShape);
    SmallVector<Value> replacements;
    replacements.reserve(dstCount);

    SmallVector<int64_t> dstStrides = computeStrides(dstShape);
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> srcIdx(srcRank);
    // Trailing dims align; size-1 source dims map to position 0.
    for (int64_t lin = 0; lin < dstCount; ++lin) {
      SmallVector<int64_t> dstIdx = delinearize(lin, dstStrides);
      for (int64_t k = 0; k < srcRank; ++k)
        srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];

      int64_t srcLin = linearize(srcIdx, srcStrides);
      replacements.push_back(srcElems.getResult(srcLin));
    }

    rewriter.replaceOp(toElementsOp, replacements);
    return success();
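
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//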
/// Folds vector.from_elements(vector.to_elements(%v)) into %v when the
/// element lists line up exactly.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
  OperandRange fromElemsOperands = fromElementsOp.getElements();
  if (fromElemsOperands.empty())
    return {};

  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
  if (!toElementsOp)
    return {};
  if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
    return {};

  // All elements come from the same to_elements op; fold to that op's vector
  // operand when the types match and the order is preserved.
  Value toElementsInput = toElementsOp.getSource();
  if (fromElementsOp.getType() == toElementsInput.getType() &&
      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
    return toElementsInput;
  }
  return {};
}
/// Fold vector.from_elements to a constant when all operands are constants.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
                                               ArrayRef<Attribute> elements) {
  if (llvm::any_of(elements, [](Attribute attr) {
        return !attr || isa<ub::PoisonAttrInterface>(attr);
      }))
    return {};

  // DenseElementsAttr only supports int/index/float/complex element types.
  auto destVecType = fromElementsOp.getDest().getType();
  auto destEltType = destVecType.getElementType();
  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
    return {};

  // Constant attributes might have a different type than the return type.
  // Convert them before creating the dense elements attribute.
  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
    return convertIntegerAttr(attr, destEltType);
  });
  return DenseElementsAttr::get(destVecType, convertedElements);
}
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (auto res = foldFromElementsToElements(*this))
    return res;
  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
    return res;

  return {};
}

/// Rewrite vector.from_elements as vector.broadcast when all the elements are
/// the same SSA value.
static LogicalResult
rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
                               PatternRewriter &rewriter) {
  if (!llvm::all_equal(fromElementsOp.getElements()))
    return failure();
  rewriter.replaceOpWithNewOp<BroadcastOp>(
      fromElementsOp, fromElementsOp.getType(),
      fromElementsOp.getElements().front());
  return success();
}
  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {
    // Handled by the splat rewrite above.
    if (fromElements.getType().getNumElements() == 1)
      return failure();

    // The common source that all elements are extracted from, if one exists.
    TypedValue<VectorType> source;
    // The position of the combined extract operation, if one is created.
    ArrayRef<int64_t> combinedPosition;
    // The expected position of the current element, if elements are
    // extracted in ascending order.
    SmallVector<int64_t> expectedPosition;

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {
      // Check that the element is from a vector.extract operation.
      auto extractOp = element.getDefiningOp<vector::ExtractOp>();
      if (!extractOp) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");
      }

      // All elements must have the same source as the first element.
      if (insertIndex == 0) {
        source = extractOp.getSource();
      } else if (extractOp.getSource() != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }

      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      // For the first element, determine the suffix of the source covered by
      // the extracted elements; it must account for all of them.
      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        int64_t numSuffixElms = 1;
        int64_t index = rank;
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);
          --index;
        }
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements do not form a suffix of source");
        }
        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);
      }

      // Subsequent elements must be extracted in ascending order.
      else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      }
      increment(expectedPosition, source.getType().getShape());
    }

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        fromElements, fromElements.getType(), extracted);

    return success();
  }
/// Increments n-D `indices` by 1, starting from the innermost dimension.
static void increment(MutableArrayRef<int64_t> indices,
                      ArrayRef<int64_t> shape) {
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    indices[dim] += 1;
    if (indices[dim] < shape[dim])
      break;
    indices[dim] = 0;
  }
}
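
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//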
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
/// Computes the positions in `dstShape` that are "dim-1" broadcasts of the
/// matching trailing dimensions of `srcShape`.
static llvm::SetVector<int64_t>
computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
                           ArrayRef<int64_t> dstShape) {
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  llvm::SetVector<int64_t> res;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");
      res.insert(dstDim);
    }
    ++dstDim;
  }
  return res;
}

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  // Scalar broadcast is without any unit dim broadcast.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (!srcVectorType)
    return {};
  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
                                      getResultVectorType().getShape());
}
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness check.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 1. If the value is a scalar, a single broadcast suffices.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 2. Since vector.broadcast only creates leading dims, build a
  // broadcastShape with the broadcast dims at the front followed by the
  // source shape, and record the permutation needed to get back to dstShape.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  // nextSrcShapeDim tracks where in the permutation the "src shape part"
  // occurs.
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // 2.a. Broadcasted dims go to the head of broadcastShape and must be
      // permuted back into position `i`.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 2.b. Non-broadcasted dims come from the src shape; the whole src
      // shape is appended after the loop.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  // 2.c. Append the srcShape.
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no "dim-1" broadcasts.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // Step 3. If any dimension needs to be permuted, emit a vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  return res;
}
BroadcastableToResult mlir::vector::isBroadcastableTo(
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcast a scalar to a vector of the same element type.
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // Source must have an exact match or singleton value for all trailing
  // dimensions (all leading dimensions are simply duplicated).
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have mismatching dims (in the sense of vector.broadcast semantics)
    // been encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine, everything else should be rejected when mixing
        // fixed-width and scalable dims.
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}
LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
    if (!srcShapeCast)
      return failure();

    VectorType srcType = srcShapeCast.getSourceVectorType();
    VectorType destType = broadcastOp.getResultVectorType();
    if (vector::isBroadcastableTo(srcType, destType) !=
        BroadcastableToResult::Success)
      return failure();

    ArrayRef<int64_t> srcShape = srcType.getShape();
    ArrayRef<int64_t> shapecastShape =
        srcShapeCast.getResultVectorType().getShape();
    // The shape_cast can only be elided when it merely adds or removes
    // leading unit dims: the trailing dims must agree.
    unsigned numTrailingDims =
        std::min(srcShape.size(), shapecastShape.size());
    if (!llvm::equal(srcShape.take_back(numTrailingDims),
                     shapecastShape.take_back(numTrailingDims)))
      return failure();

    assert(all_of(srcShape.drop_back(numTrailingDims),
                  [](int64_t E) { return E == 1; }) &&
           all_of(shapecastShape.drop_back(numTrailingDims),
                  [](int64_t E) { return E == 1; }) &&
           "ill-formed shape_cast");

    broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
    return success();
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());
  return {};
}
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
                                             broadcastOp.getResultVectorType(),
                                             srcBroadcast.getSource());
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder>(context);
}
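
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//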
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // Construct the resulting type: the leading dimension matches the mask
  // length, all trailing dimensions match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // A 0-D shuffle has a 1-D result type; this case must be handled by a
  // canonicalization into a broadcast instead.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffle V1, V2 with an identity-prefix mask to V1 or V2.
  auto mask = getMask();
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return getV1();
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return getV2();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // Poison input attributes need special handling as they are not
  // DenseElementsAttr. If an index is poison, we select the first element of
  // the first non-poison input.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    if (!v2DenseAttr)
      return {};
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    if (!v1DenseAttr)
      return {};
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    // TODO: Return a partial poison vector when supported by the UB dialect.
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
      return failure();
    if (mask.size() != 1)
      return failure();
/// Returns the scalar source of a broadcast-from-scalar that defines `value`,
/// or a null value otherwise.
static Value getScalarSplatSource(Value value) {
  Operation *defOp = value.getDefiningOp();
  if (!defOp)
    return {};
  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
  // Not a broadcast.
  if (!broadcast)
    return {};
  // Not a broadcast of a scalar.
  if (isa<VectorType>(broadcast.getSourceType()))
    return {};
  return broadcast.getSource();
}
    Value splat = getScalarSplatSource(op.getV1());
    if (!splat || getScalarSplatSource(op.getV2()) != splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent a scalable interleave");

    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp types don't match an interleave");
    }

    ArrayRef<int64_t> shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
        return rewriter.notifyMatchFailure(op,
                                           "ShuffleOp mask not interleaving");
    }

    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
    return success();
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);
}
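
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//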
void InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest) {
  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<int64_t> position) {
  SmallVector<OpFoldResult> posVals;
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

LogicalResult InsertOp::verify() {
  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
    if (srcTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the source operand");

  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding dest vector dimension";
      }
    }
  }
  return success();
}
/// Linearizes the given static `positions` within `destTy`'s shape, treating
/// unspecified trailing positions as zero.
static int64_t calculateInsertPosition(VectorType destTy,
                                       ArrayRef<int64_t> positions) {
  SmallVector<int64_t> completePositions(destTy.getRank(), 0);
  assert(positions.size() <= completePositions.size() &&
         "positions size must be less than or equal to destTy rank");
  copy(positions, completePositions.begin());
  return linearize(completePositions, computeStrides(destTy.getShape()));
}
    VectorType srcVecType =
        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
    return success();
    Value splat = getScalarSplatSource(op.getValueToStore());
    if (!splat || getScalarSplatSource(op.getDest()) != splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
/// Rewrites a chain of inserts that fully initializes a vector into a single
/// vector.from_elements op.
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    VectorType destTy = op.getDestVectorType();
    if (destTy.isScalable())
      return failure();
    // Make sure we are the last insert of the chain.
    for (Operation *user : op.getResult().getUsers())
      if (auto insertOp = dyn_cast<InsertOp>(user))
        if (insertOp.getDest() == op.getResult())
          return failure();

    InsertOp currentOp = op;
    SmallVector<InsertOp> chainInsertOps;
    while (currentOp) {
      // Dynamic positions are not supported.
      if (currentOp.hasDynamicPosition())
        return failure();

      chainInsertOps.push_back(currentOp);
      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
      // Intermediate inserts must have a single use to avoid an explosion of
      // vectors.
      if (currentOp && !currentOp->hasOneUse())
        return failure();
    }

    int64_t vectorSize = destTy.getNumElements();
    int64_t initializedCount = 0;
    SmallVector<bool> initializedDestIdxs(vectorSize, false);
    SmallVector<int64_t> pendingInsertPos;
    SmallVector<int64_t> pendingInsertSize;
    SmallVector<Value> pendingInsertValues;

    for (auto insertOp : chainInsertOps) {
      // The folder will fold an insert at a poison index into poison.
      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
        return failure();

      // Calculate the linearized position of the insert.
      int64_t insertBeginPosition =
          calculateInsertPosition(destTy, insertOp.getStaticPosition());

      int64_t insertSize = 1;
      if (auto srcVectorType =
              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
        insertSize = srcVectorType.getNumElements();

      assert(insertBeginPosition + insertSize <= vectorSize &&
             "insert would overflow the vector");

      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
                                           insertBeginPosition + insertSize)) {
        if (initializedDestIdxs[index])
          continue;
        initializedDestIdxs[index] = true;
        ++initializedCount;
      }

      // Defer op creation until we are sure the pattern succeeds.
      pendingInsertPos.push_back(insertBeginPosition);
      pendingInsertSize.push_back(insertSize);
      pendingInsertValues.push_back(insertOp.getValueToStore());

      if (initializedCount == vectorSize)
        break;
    }

    // Some positions are not initialized.
    if (initializedCount != vectorSize)
      return failure();

    SmallVector<Value> elements(vectorSize);
    for (auto [insertBeginPosition, insertSize, valueToStore] :
         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
                                 pendingInsertValues))) {
      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());

      if (!srcVectorType) {
        elements[insertBeginPosition] = valueToStore;
        continue;
      }

      SmallVector<Type> elementToInsertTypes(insertSize,
                                             srcVectorType.getElementType());
      // Get all elements from the vector in row-major order.
      auto elementsToInsert = vector::ToElementsOp::create(
          rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
        elements[insertBeginPosition + linearIdx] =
            elementsToInsert.getResult(linearIdx);
      }
    }

    rewriter.replaceOpWithNewOp<FromElementsOp>(op, destTy, elements);
    return success();
  }
};
static Value foldDenseElementsAttrDestInsertOp(
    InsertOp insertOp, Attribute srcAttr, Attribute dstAttr,
    int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst || !srcAttr)
    return {};

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  int64_t insertBeginPosition =
      calculateInsertPosition(destTy, insertOp.getStaticPosition());
  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(convertIntegerAttr(value, destEltType));
  } else {
    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
  }

  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);
  return DenseElementsAttr::get(destTy, allValues);
}

/// Fold an insert whose dest is produced by another insert at the same
/// position: forward the dest of that insert.
static Value foldInsertUseChain(InsertOp insertOp) {
  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
  if (!destInsert)
    return {};

  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
    return {};

  insertOp.setOperand(1, destInsert.getDest());
  return insertOp.getResult();
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertChainFullyInitialized>(context);
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;

  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>"
  // (no position) to %v.
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  if (Value v = foldInsertUseChain(*this))
    return v;
  // Fold `arith.constant` indices into the op in place.
  SmallVector<Value> operands = {getValueToStore(), getDest()};
  OpFoldResult inplaceFolded =
      extractInsertFoldConstantOp(*this, adaptor, operands);
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }

  return inplaceFolded;
}
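//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//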
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}

/// Returns true if all integers in `arrayAttr` are in the interval [min, max).
/// If `halfOpen` is false, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}

/// Returns true if all integers in `arrayAttr` are in the interval
/// [min, shape[i]), per dimension. If `halfOpen` is false, the admissible
/// interval is [min, shape[i]].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}

/// Returns true if, for each dimension, the sum of corresponding elements of
/// `arrayAttr1` and `arrayAttr2` is confined to [min, shape[i]).
/// If `halfOpen` is false, the admissible interval is [min, shape[i]].
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
/// Pattern to rewrite an insert_strided_slice of a splat into a matching
/// splat destination as the destination itself.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto dst = insertStridedSliceOp.getDest();
    auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
    if (!splat || getScalarSplatSource(dst) != splat)
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, dst);
    return success();
  }
};

/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getValueToStore()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check that the slices have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};

/// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
/// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Bail unless both operands are defined by compatible vector constants.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    TypedValue<VectorType> sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Support poison.
    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Copy the slice values into the destination values, walking all slice
    // positions and linearizing each into the destination.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};

void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getValueToStore();
  return {};
}
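//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//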
void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}

ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand.
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}

LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
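//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//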
// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for remaining dims.
// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}

void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  if (failed(
          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
      failed(
          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(
          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }

  return success();
}
// When the source of ExtractStrided comes from a chain of InsertStrided ops,
// try to use the source of an InsertStrided op if we can detect that the
// extracted vector is a subset of one of the inserted vectors.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract integer out of ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the rank of extract is greater than the rank of insert, we are likely
    // extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        // If the extract interval overlaps but is not fully included, we have
        // a partial overlap.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk is a subset of the inserted chunk.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getValueToStore());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunks are disjoint, keep looking through the insert chain;
    // otherwise (partial overlap) we cannot fold.
    if (!disjoint)
      return failure();
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return failure();
}

// Compute the sliced constant for a non-splat dense source.
static OpFoldResult
foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
                                        Attribute foldInput) {
  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
  if (!dense)
    return {};

  // TODO: Handle non-unit strides when they become available.
  if (op.hasNonUnitStrides())
    return {};

  VectorType sourceVecTy = op.getSourceVectorType();
  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

  VectorType sliceVecTy = op.getType();
  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
  int64_t rank = sliceVecTy.getRank();

  // Expand offsets to match the rank of the source vector.
  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
  offsets.resize(rank, 0);

  // Walk all slice positions, linearize each into the source, and copy the
  // corresponding element.
  const auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(sliceVecTy.getNumElements());
  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
  do {
    int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
  } while (succeeded(
      incSlicePosition(currSlicePosition, sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
             sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  return DenseElementsAttr::get(sliceVecTy, sliceValues);
}

OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getSource();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();

  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());

  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
}
// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to CreateMaskOp.
class StridedSliceCreateMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    Location loc = extractStridedSliceOp.getLoc();
    // Return if the source is not defined by a CreateMaskOp.
    auto createMaskOp =
        extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
    if (!createMaskOp)
      return failure();
    // Return if the op has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather mask dimension sizes, slice offsets and sizes.
    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
    // Compute the slice of the mask region.
    SmallVector<Value> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    // No need to clamp: create_mask has clamping semantics, i.e. the computed
    // size may be negative or greater than the vector dim size.
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      IntegerAttr offsetAttr =
          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
      Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
      Value sliceMaskDimSize =
          arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      llvm::append_range(
          sliceMaskDimSizes,
          llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
    // Replace with a CreateMaskOp over the sliced mask region.
    rewriter.replaceOpWithNewOp<CreateMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};

// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the source is not defined by a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if the op has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather mask dimension sizes, slice offsets and sizes.
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
    // Compute the slice of the mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any slice mask dim size is zero, then all must be zero (the masked
    // region is a conjunction of the mask dim intervals).
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
    // Replace with a ConstantMaskOp over the sliced mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};

// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the slice is smaller than the source of the broadcast in any
    // non-unit dimension; only then do we need to slice the source.
    bool needsSlice = false;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != 1 &&
          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        needsSlice = true;
        break;
      }
    }
    Value source = broadcast.getSource();
    if (needsSlice) {
      SmallVector<int64_t> offsets =
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
      SmallVector<int64_t> sizes =
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
      for (unsigned i = 0; i < srcRank; i++) {
        if (srcVecType.getDimSize(i) == 1) {
          // In case this dimension was broadcast from a unit dim, don't
          // slice it.
          offsets[i] = 0;
          sizes[i] = 1;
        }
      }
      source = ExtractStridedSliceOp::create(
          rewriter, op->getLoc(), source, offsets, sizes,
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};

/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getSource());
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build: walk the
    // dimensions from innermost out and stop at the first non-full-size slice
    // dimension.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // With no offsets this op is the identity; other canonicalizations handle
    // that case.
    if (numOffsets == 0)
      return failure();

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
                                              extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
              StridedSliceBroadcast, StridedSliceSplat,
              ContiguousExtractStridedSliceToExtract>(context);
}
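//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//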
/// Builder that sets padding to a poison value if not provided and sets an
/// empty mask (variant with attributes).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr) {
  if (!padding) {
    Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  }
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}

/// Builder that sets padding to a poison value if not provided and sets an
/// empty mask (variant with `ArrayRef<bool>` in-bounds).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  if (!padding) {
    Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  }
  build(builder, result, vectorType, source, indices, *padding,
        permutationMapAttr, inBoundsAttr);
}

/// Builder that sets the permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  if (!padding) {
    Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  }
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}

static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds attribute if all dims are out-of-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}

VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

  SmallVector<bool> scalableDims =
      applyPermutationMap(invPermMap, vecType.getScalableDims());

  // 0-D masks are not supported: use a 1-D shape instead.
  if (maskShape.empty())
    maskShape.push_back(1);

  return VectorType::get(maskShape, i1Type, scalableDims);
}

ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}

LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type: check that the padding type matches.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

/// Returns the mask type expected by this operation.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

VectorType TransferReadOp::getVectorType() {
  return cast<VectorType>(getVector().getType());
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}

template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Idxs of non-broadcast dims, used when analysing broadcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }

    changed |= inBounds;
    newInBounds.push_back(inBounds);
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are in-bounds, then
  // all broadcast dims are in-bounds as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}

template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();

  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();

  op.getMaskMutable().clear();
  return success();
}
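// Illustrative folding example (not from the source): reading back a vector
// just written at the same indices yields the written vector directly:
//   %w = vector.transfer_write %v, %t[%c0, %c0] {in_bounds = [true, true]}
//          : vector<1x4xf32>, tensor<4x4xf32>
//   %r = vector.transfer_read %w[%c0, %c0], %pad {in_bounds = [true, true]}
//          : tensor<4x4xf32>, vector<1x4xf32>
// folds %r to %v.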
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  // ... (in-bounds, full-mask, and cast foldings elided in the source excerpt)
  return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}

/// Rewrite a read of a tensor value just written by a transfer_write into a
/// broadcast/transpose of the written vector, when the permutation maps allow.
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if we need an alias analysis.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail if we need a bounds analysis.
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // This pattern should only catch the broadcast case; keep the masked case
    // separate.
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    // If the indices are not the same a shift may be required; bail.
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();

    Value vec = defWrite.getVector();
    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
    AffineMap map = readMap.compose(writeMap);
    if (map.getNumResults() == 0)
      return failure();
    // Calculate the permutation to apply to go from the stored vector to the
    // read vector.
    SmallVector<unsigned> permutation;
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
      return failure();

    Location loc = readOp.getLoc();
    // Calculate the broadcast shape by applying the reverse permutation to the
    // final desired shape.
    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
    SmallVector<int64_t> broadcastShape(destShape.size());
    SmallVector<bool> broadcastScalableFlags(destShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType(),
        broadcastScalableFlags);
    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
                                                     transposePerm);
    return success();
  }
};

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  // ... (rest elided in the source excerpt)
}
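//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//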
/// Builder with type inference.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// Builder with type inference that sets an empty mask (variant with attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// Builder with type inference that sets an empty mask (variant without
/// attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// Builder with type inference that sets an empty mask and sets the
/// permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}

ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}

LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

/// Returns the mask type expected by this operation.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

Value TransferWriteOp::getVector() { return getOperand(0); }

VectorType TransferWriteOp::getVectorType() {
  return cast<VectorType>(getValueToStore().getType());
}

/// Fold a transfer_write that writes back a value read from the same tensor
/// at the same indices (read-init-write).
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero, bail.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getBase());
  return success();
}

static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getBase() == write.getBase() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}

/// Fold a write-after-read into the read's source tensor when the write
/// stores back exactly what was read.
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getBase());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  // ... (in-bounds, full-mask, and cast foldings elided in the source excerpt)
  return failure();
}

std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
/// Remove a dead transfer write from the SSA chain so that it can be
/// eliminated by DCE (write-after-write elimination).
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write has no other use, we can safely look at the
      // write before it.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};

/// Swap a tensor.insert_slice of a tensor.extract_slice of a transfer_write
/// so that the write happens directly into the extracted slice.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if the op chain is rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if the TransferWriteOp has non-zero indices.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TranferWriteOp has non-zero offset");
    }

    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp ranks differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the TransferWriteOp may not overwrite the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the tensor::ExtractSliceOp in front of the TransferWriteOp.
    // Set all in_bounds to false and let the folder infer them.
    SmallVector<bool> newInBounds(vectorShape.size(), false);
    auto newExtractOp = tensor::ExtractSliceOp::create(
        rewriter, extractOp.getLoc(), insertOp.getSourceType(),
        insertOp.getDest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    auto newTransferWriteOp = TransferWriteOp::create(
        rewriter, transferOp.getLoc(), transferOp.getVector(),
        newExtractOp.getResult(), transferOp.getIndices(),
        transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(newInBounds));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  // ... (rest elided in the source excerpt)
}
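//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//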
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If rank == 0 or size == 1, the op is lowered directly to llvm.load/store;
  // any layout is fine in that case.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}

LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < resVecTy.getRank())
    return emitOpError(
        "destination memref has lower rank than the result vector");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ... (body elided in the source excerpt)
}
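//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//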
LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < valueVecTy.getRank())
    return emitOpError("source memref has lower rank than the vector to store");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ... (body elided in the source excerpt)
}
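//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//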
LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// Fold masked loads with an all-true or all-false mask.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ... (body elided in the source excerpt)
}
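//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//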
LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}

/// Fold masked stores with an all-true or all-false mask.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ... (body elided in the source excerpt)
}
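//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//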
LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// Returns the mask type expected by this operation.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// Cheaply check whether the index vector is the 0-based contiguous sequence
/// [0, 1, 2, ...].
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}

/// Fold gathers with an all-false mask to their pass-thru value.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

/// Fold gathers with consecutive offsets [0, 1, 2, ...] into a contiguous
/// maskedload. Only 1-D fixed vectors are supported for now.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getOffsets(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
  // ... (body elided in the source excerpt)
}
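//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//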
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getOffsets()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

/// Fold scatters with an all-false mask by erasing them.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

/// Fold scatters with consecutive offsets [0, 1, 2, ...] into a contiguous
/// maskedstore. Only 1-D fixed vectors are supported for now.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
    return success();
  }
};

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  // ... (body elided in the source excerpt)
}
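//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//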
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// Fold expanding loads with an all-true or all-false mask.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ... (body elided in the source excerpt)
}
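//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//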
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

/// Fold compressing stores with an all-true or all-false mask.
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ... (body elided in the source excerpt)
}
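//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//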
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

LogicalResult ShapeCastOp::verify() {
  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that element type is preserved.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitOpError("has different source and result element types");

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}

/// Return true if `transpose` does not permute a pair of non-unit dims.
/// In other words, the flattened input and output vectors are (numerically)
/// identical and the transpose is effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current) {
        return false;
      }
      current = p;
    }
  }
  return true;
}

OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x), if order preserving.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // Y = shape_cast(broadcast(X)) -> X, if X and Y have the same type.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> constant.
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  // shape_cast(poison) -> poison.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());

  return {};
}

/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
///   vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Folds qualifying shape_cast(create_mask) into a new create_mask.
///
/// Looks at `vector.shape_cast` ops that simply "drop" trailing unit
/// dimensions. If the corresponding mask input values are 1, it is safe to
/// fold the shape_cast into the mask op.
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    // No unit dims to drop.
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check every mask dim size to see whether it can be dropped.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check every mask dim size to see whether it can be dropped.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};

/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
/// i) Y = ShapeCast(X), or
/// ii) Y = Broadcast(X).
/// If both are possible, (i) is chosen.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;
    (void)srcIsScalar;

    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X) when the
    // element counts match.
    if (srcVectorType) {
      if (srcVectorType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    // Otherwise, replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
    // when X is directly broadcastable to the result type.
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    if (vector::isBroadcastableTo(broadcastOp.getSourceType(), dstVectorType) ==
        BroadcastableToResult::Success) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, dstVectorType, broadcastOp.getSource());
      return success();
    }
    return failure();
  }
};

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results
      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
          context);
}
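//===----------------------------------------------------------------------===//
// BitCastOp
//===----------------------------------------------------------------------===//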
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}

OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern into both halves.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower-width element into the wider value.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(),
                                        IntegerAttr::get(dstElemType, intBits));
        }
      }
    }
  }

  return {};
}
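//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//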
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4x5x6xf32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}

LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}

OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity transposes, and more generally any transpose that
  // preserves the order of the elements.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return {};
}

LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};

/// Replace transpose(splat-like(v)) with broadcast(v).
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(transposeOp.getVector());
    if (!splat)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transposeOp, transposeOp.getResultVectorType(), splat);
    return success();
  }
};

/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};

/// Folds transpose(shape_cast) into a new shape_cast when the transpose is
/// order preserving.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    // Merging the transpose into the nascent shape_cast is always valid here,
    // so no isValidShapeCast check is needed.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};

/// Folds transpose(broadcast(x)) to broadcast(x) if the permutation only
/// moves dims within groups of broadcast unit dims, i.e. the transpose is
/// order preserving on the broadcast result.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          transpose, outputType, broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        // Fail unless all permutation destinations for indices in [low, high)
        // are in [low, high).
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeBroadcast>(context);
}
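//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//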
6947 assert(
kind == ConstantMaskKind::AllTrue ||
6948 kind == ConstantMaskKind::AllFalse);
6949 build(builder, result, type,
6950 kind == ConstantMaskKind::AllTrue
LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that the array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");

  // Verify that each mask dim size is within the bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }

  // If one mask dim size is zero, all of them must be zero (the mask region
  // is the conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");

  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
  // ...
  auto createBoolSplat = [&](bool x) {
    // ... builds a dense i1 splat of the result type with value `x`.
  };
  // 0-D masks fold to their single bit.
  if (vectorSizes.empty()) {
    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
    return createBoolSplat(bounds[0] == 1);
  }
  // All-set masks fold to a `true` splat, all-unset masks to `false`.
  if (bounds == vectorSizes)
    return createBoolSplat(true);
  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
    return createBoolSplat(false);
  // ...
}
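// For example (illustrative): vector.constant_mask [4, 4] : vector<4x4xi1>
// folds to a dense<true> splat, and [0, 0] folds to dense<false>.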
// CreateMaskOp::build (fragment): delegate to the value-operand builder.
  build(builder, result, type, operands);

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
// CreateMaskFolder: rewrites vector.create_mask with constant operands into
// an equivalent vector.constant_mask.
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: treat rank-zero masks as a single dimension of size 1.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Collect constant mask dimension sizes; fail if an operand is neither a
    // constant nor a constant multiple of vscale.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Scalable dimensions only support the all-false (zero) constant.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // A vscale multiple must cover the whole scalable dim (all-true).
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp the values to the constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If any dimension is zero, the whole mask is all-false.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);
    // ...
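// Sketch of the rewrite (illustrative IR, not from this file):
//
//   %c2 = arith.constant 2 : index
//   %m = vector.create_mask %c2, %c2 : vector<4x4xi1>
//
// folds to
//
//   %m = vector.constant_mask [2, 2] : vector<4x4xi1>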
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
// MaskOp::build (fragments): the callback populates the mask region.
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");
  // ...
  maskRegionBuilder(builder, maskableOp);

  // Overload without a passthru value:
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
  // ...
  build(builder, result, mask, maskableOp, maskRegionBuilder);
// MaskOp::parse (fragments).
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();
  // ...
  MaskOp::ensureTerminator(maskRegion, builder, result.location);
  // ...
  result.types.append(resultTypes);
  // ...
  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          maskTypeLoc, "expects a result if passthru operand is provided");
  // ...
void MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print the single masked operation, skipping the implicit terminator.
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  // ...
  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
// MaskOp::ensureTerminator (fragments): fall back to the default implicit
// terminator unless the region holds a single maskable op to yield from.
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // ...
  if (isa<vector::YieldOp>(block.back()))
    return;
  // ...
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // ...
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // An empty `vector.mask` has nothing else to verify.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Folds empty `vector.mask` ops away (no passthru): the op simply forwards
/// the values yielded by its terminator, if any.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();

  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  if (terminator.getNumOperands() == 0) {
    // `vector.mask` has no results: just remove it.
    return success();
  }

  llvm::append_range(results, terminator.getOperands());
  return success();
}

LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();
  // ...
  Operation *maskableOp = getMaskableOp();
  // ...
  llvm::append_range(results, maskableOp->getResults());
  return success();
}
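// For example (illustrative), an empty mask without passthru
//
//   %0 = vector.mask %m { vector.yield %a : vector<8xf32> }
//       : vector<8xi1> -> vector<8xf32>
//
// folds directly to %a.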
// CanonializeEmptyMaskOp: an empty mask with a passthru is just a select
// between the yielded value and the passthru.
    if (!maskOp.isEmpty())
      return failure();
    if (!maskOp.hasPassthru())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
        terminator.getOperand(0), maskOp.getPassthru());
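// Sketch of the rewrite (illustrative IR, not from this file):
//
//   %0 = vector.mask %m, %pt { vector.yield %a : vector<8xf32> }
//       : vector<8xi1> -> vector<8xf32>
//
// becomes
//
//   %0 = arith.select %m, %a, %pt : vector<8xi1>, vector<8xf32>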
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
/// Returns the operation masked by this 'vector.mask', if any.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  // ...
  return &block->front();
}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check that the reduction dimension is in range.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) == rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes: the initial value must be the source shape with the
  // reduction dimension dropped.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
// StepOp range inference: a non-scalable step vector holds 0 .. dim-1.
  auto resultType = cast<VectorType>(getType());
  if (resultType.isScalable()) {
    return;
  }
  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
  APInt zero(bitwidth, 0);
  APInt high(bitwidth, resultType.getDimSize(0) - 1);
  ConstantIntRanges result = {zero, high, zero, high};
  setResultRanges(getResult(), result);
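// For example, `%s = vector.step : vector<8xindex>` produces the values
// 0, 1, ..., 7, so the inferred signed and unsigned ranges are both [0, 7].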
// maskOperation (fragments): wrap `maskableOp` in a `vector.mask` op.
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  // ...
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, passthru,
                        maskableOp, createMaskOpRegion);
  // ...
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, maskableOp,
                        createMaskOpRegion);

Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
                                 mask, newValue, passthru);
}
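// Illustrative use (assumed, not from this file): selectPassthru(builder, %m,
// %new, %pt) materializes
//
//   %0 = arith.select %m, %new, %pt : vector<8xi1>, vector<8xf32>
//
// picking lanes from %new where the mask is set and from %pt elsewhere.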
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"