#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"

#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"

// From getMaskFormat: scan the dense mask elements, counting up for set bits
// and down for cleared bits; a mix of both means the mask format is unknown.
for (bool b : denseElts.getValues<bool>())
  if (b && val >= 0)
    val++;
  else if (!b && val <= 0)
    val--;
  else
    return MaskFormat::Unknown;
// From getMaskFormat, handling a vector.constant_mask with a single mask
// dimension: compare the prefix length against the vector length.
ArrayAttr masks = m.getMaskDimSizes();
assert(masks.size() == 1);
int64_t i = masks[0].cast<IntegerAttr>().getInt();
int64_t u = m.getType().getDimSize(0);
switch (combiningKind) {
case CombiningKind::ADD:
case CombiningKind::MUL:
  return elementType.isIntOrIndex() || elementType.isa<FloatType>();
case CombiningKind::MINUI:
case CombiningKind::MINSI:
case CombiningKind::MAXUI:
case CombiningKind::MAXSI:
case CombiningKind::AND:
case CombiningKind::OR:
case CombiningKind::XOR:
  return elementType.isIntOrIndex();
case CombiningKind::MINF:
case CombiningKind::MAXF:
  return elementType.isa<FloatType>();
}
return false;
return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
int64_t elementVectorRank = 0;
VectorType elementVectorType =
    shapedType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
  elementVectorRank += elementVectorType.getRank();
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
if (shapedType.getRank() == 0 &&
    vectorType.getShape() == ArrayRef<int64_t>{1})
  return AffineMap::get(
      /*numDims=*/0, /*numSymbols=*/0,
      getAffineConstantExpr(0, shapedType.getContext()));
return AffineMap::getMinorIdentityMap(
    shapedType.getRank(), vectorType.getRank() - elementVectorRank,
    shapedType.getContext());
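// Illustrative note (added, not from the original source): for a transfer
// between memref<10x20x30xf32> and vector<20x30xf32>, the minor identity map
// built above is (d0, d1, d2) -> (d1, d2): the trailing memref dims feed the
// vector dims and the leading dim is dropped.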
static bool checkSameValueRAW(vector::TransferWriteOp defWrite,
                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap();
}
static bool checkSameValueWAW(vector::TransferWriteOp write,
                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
  // For simplicity only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
    auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
    auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
    // If any index is dynamic we cannot prove anything.
    if (!indexA || !indexB)
      continue;

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different we know we are accessing disjoint slices.
      if (indexA.getValue().cast<IntegerAttr>().getInt() !=
          indexB.getValue().cast<IntegerAttr>().getInt())
        return true;
    } else {
      // For this dimension, we slice a part of the memref, so the intervals
      // accessed must not overlap.
      int64_t distance =
          std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
                   indexB.getValue().cast<IntegerAttr>().getInt());
      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
        return true;
    }
  }
  return false;
}
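// Illustrative sketch (added for exposition): with constant indices %c0/%c8
// and vector<8xf32> transfers on the same source, accesses at [%i, %c0] and
// [%i, %c8] are proven disjoint by the distance check above, since the
// distance 8 is >= the vector size along the sliced dimension.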
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB) {
  if (transferA.source() != transferB.source())
    return false;
  return isDisjointTransferIndices(transferA, transferB);
}
CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
                                         MLIRContext *context) {
  return Base::get(context, static_cast<uint64_t>(kind));
}

CombiningKind CombiningKindAttr::getKind() const {
  return static_cast<CombiningKind>(getImpl()->value);
}

// Fragment of the list of supported combining kinds:
    CombiningKind::MINUI,
    CombiningKind::MINSI,
    CombiningKind::MAXUI,
    CombiningKind::MAXSI,

// From CombiningKindAttr::print: emit only the kinds contained in this
// attribute's bit-enum value.
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
  return bitEnumContains(this->getKind(), kind);
});
llvm::interleaveComma(kinds, printer,
                      [&](auto kind) { printer << stringifyEnum(kind); });
// From CombiningKindAttr::parse:
auto kind = symbolizeCombiningKind(elemName);

// From VectorDialect::parseAttribute:
if (attrKind == "kind")
  return CombiningKindAttr::parse(parser);

void VectorDialect::printAttribute(Attribute attr,
                                   DialectAsmPrinter &os) const {
  if (auto ck = attr.dyn_cast<CombiningKindAttr>())
    ck.print(os);
  else
    llvm_unreachable("Unknown attribute type");
}

void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();
}

Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc,
        builder.getI64ArrayAttr(reductionDims));
}

OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  Type inferredReturnType;
  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
          return attr.cast<IntegerAttr>().getValue() == it.index();
        }))
      targetShape.push_back(it.value());
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType =
        VectorType::get(targetShape, getSourceVectorType().getElementType());
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();
  return success();
}
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector) {
  build(builder, result, kind, vector, /*acc=*/Value());
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc) {
  build(builder, result,
        vector.getType().cast<VectorType>().getElementType(), kind, vector,
        acc);
}
LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
// From ReductionOp::parse: resolve the vector operand and, if present, the
// optional accumulator operand.
if ((!operandsInfo.empty() &&
     parser.resolveOperand(operandsInfo[0], /*...*/)) ||
    (operandsInfo.size() > 1 &&
     parser.resolveOperand(operandsInfo[1], /*...*/)))
  return failure();
if (operandsInfo.empty() || operandsInfo.size() > 2)
  return parser.emitError(parser.getNameLoc(),
                          "unsupported number of operands");

void ReductionOp::print(OpAsmPrinter &p) {
  p << " ";
  getKindAttr().print(p);
  p << ", " << getVector();
  if (getAcc())
    p << ", " << getAcc();
  p << " : " << getVector().getType() << " into " << getDest().getType();
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maxf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  default:
    break;
  }
  return nullptr;
}
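// Minimal usage sketch (added; hypothetical caller, not in the original file):
//   Value sum = mlir::vector::getVectorReductionOp(
//       arith::AtomicRMWKind::addf, builder, loc, myVector);
// A null return signals an unsupported AtomicRMWKind.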
// Canonicalization: a reduction over a single-element vector folds to an
// extract, combined with the accumulator if present.
LogicalResult matchAndRewrite(ReductionOp reductionOp,
                              PatternRewriter &rewriter) const override {
  if (reductionOp.getVectorType().getDimSize(0) != 1)
    return failure();

  Location loc = reductionOp.getLoc();
  Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
                                            reductionOp.getVector(),
                                            rewriter.getI64ArrayAttr(0));
  if (Value acc = reductionOp.getAcc())
    result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                        result, acc);

  rewriter.replaceOp(reductionOp, result);
  return success();
}

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrStrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrStrName(), iteratorTypes);
  result.addAttribute(ContractionOp::getKindAttrStrName(),
                      CombiningKindAttr::get(kind, builder.getContext()));
}
// From ContractionOp::parse: trait attributes come in as a leading
// dictionary; the kind attribute defaults to getDefaultKind() when absent.
DictionaryAttr dictAttr;
// ...
result.attributes.assign(dictAttr.getValue().begin(),
                         dictAttr.getValue().end());
if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
  result.addAttribute(
      ContractionOp::getKindAttrStrName(),
      CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                             result.getContext()));
}
if (masksInfo.empty())
  return success();
if (masksInfo.size() != 2)
  return parser.emitError(parser.getNameLoc(),
                          "expected zero or exactly 2 vector mask operands");
auto lhsType = types[0].cast<VectorType>();
auto rhsType = types[1].cast<VectorType>();
auto maskElementType = parser.getBuilder().getI1Type();
std::array<Type, 2> maskTypes = {
    VectorType::Builder(lhsType).setElementType(maskElementType),
    VectorType::Builder(rhsType).setElementType(maskElementType)};
void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs())
    if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();
  if (getMasks().size() == 2)
    p << ", " << getMasks();
  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer the expected result vector type from the lhs/rhs maps and vector
    // types; this assumes the affine maps are already verified.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    // ... compose the constant extents with resMap; every result must then be
    // a constant expression ...
    assert(llvm::all_of(expectedMap.getResults(),
                        [](AffineExpr e) {
                          return e.isa<AffineConstantExpr>();
                        }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
LogicalResult ContractionOp::verify() {
  auto lhsType = getLhsType();
  auto rhsType = getRhsType();
  auto accType = getAccType();
  auto resType = getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }
  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = getLHSVectorMaskType();
  auto rhsMaskType = getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {getIndexingMapsAttrStrName(),
                                         getIteratorTypesAttrStrName(),
                                         ContractionOp::getKindAttrStrName()};
  return llvm::makeArrayRef(names);
}
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(),
                   getParallelIteratorTypeName(), getContext());
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
// Fold add(contract(a, b, zero), c) into contract(a, b, c).
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}

OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
  // Skip the 0-D vector here for now.
  if (operands.size() < 2)
    return {};

  Attribute src = operands[0];
  Attribute pos = operands[1];

  // Fold extractelement (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  if (!pos || !src)
    return {};

  auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();

  auto attr = pos.dyn_cast<IntegerAttr>();
  uint64_t posIdx = attr.getInt();

  return srcElements[posIdx];
}
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
      }));
  build(builder, result, source, positionConstants);
}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
                            ValueRange operands, DictionaryAttr attributes,
                            RegionRange,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  ExtractOp::Adaptor op(operands, attributes);
  auto vectorType = op.getVector().getType().cast<VectorType>();
  if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n =
        std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType()));
  }
  return success();
}
bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = l.front().dyn_cast<VectorType>();
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}

LogicalResult vector::ExtractOp::verify() {
  auto positionAttr = getPosition().getValue();
  if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (const auto &en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= getVectorType().getDimSize(en.index()))
      return emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extrPos = extractVector<int64_t>(currentOp.getPosition());
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extrPos = extractVector<int64_t>(currentOp.getPosition());
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(currentOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}
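// Illustrative fold performed above (added sketch):
//   %1 = vector.extract %0[0] : vector<2x3x4xf32>
//   %2 = vector.extract %1[1] : vector<3x4xf32>
// becomes a single op with concatenated positions:
//   %2 = vector.extract %0[0, 1] : vector<2x3x4xf32>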
/// Walks an insert/transpose chain backwards to fold an ExtractOp.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over the chain and fold if possible.
  Value fold();

private:
  /// Returns true if the vector at position `a` is contained within the
  /// vector at position `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Returns true if the vector at position `a` intersects the vector at
  /// position `b`. Negative (sentinel) entries are treated as wildcards that
  /// always intersect.
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto it : llvm::zip(a, b)) {
      if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
        continue;
      if (std::get<0>(it) != std::get<1>(it))
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in
  /// the result of the insert.
  bool canFold() {
    return (sentinels ==
            makeArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to update the state for the next iteration.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<InsertOp>();
    nextTransposeOp = v.getDefiningOp<TransposeOp>();
  }

  LogicalResult handleTransposeOp();
  LogicalResult handleInsertOpWithMatchingPos(Value &res);
  LogicalResult handleInsertOpWithPrefixPos(Value &res);
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the distance of the extract from the insert.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
      extractedRank(extractOp.getPosition().size()) {
  assert(vectorRank >= extractedRank && "extracted pos overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition = extractVector<int64_t>(extractOp.getPosition());
  llvm::append_range(extractPosition, sentinels);
}
// Case 1: if we hit a transpose, just compose the permutation and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (!nextTransposeOp)
    return failure();
  auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
  // ... apply the inverse permutation to extractPosition ...
  return success();
}

// Case 2: the insert position matches extractPosition exactly; early return.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
  if (makeArrayRef(insertedPos) !=
      llvm::makeArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  // Case 2.a: early-exit fold.
  res = nextInsertOp.getSource();
  return success();
}
// Case 3: if the insert position is a prefix of extractPosition, extract a
// portion of the source of the insert.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // ... drop the matched prefix from extractPosition ...
  res = nextInsertOp.getSource();
  return success();
}

/// Try to fold in place to extract(source, extractPosition) and return the
/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // If there is nothing to fold, or an internal transposition prevents
  // folding, bail out.
  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
    return Value();
  // Otherwise, fold by updating the op in place and returning its result.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(
      extractOp.getPositionAttrName(),
      b.getI64ArrayAttr(
          makeArrayRef(extractPosition).take_front(extractedRank)));
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();
}
/// Iterate over the insert/transpose chain and fold if possible.
Value ExtractFromInsertTransposeChainState::fold() {
  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1: if we hit a transpose, just compose the map and iterate.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: the inserted position is a prefix of extractPosition.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: bail if the insert and the extract may overlap on non-sentinel
    // positions.
    auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
    if (intersectsWhereNonNegative(
            makeArrayRef(extractPosition).take_front(insertedPos.size()),
            insertedPos))
      return Value();

    // Case 5: the insert writes a disjoint part; look through its dest.
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
/// Fold an ExtractOp whose source comes from a BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();
  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(source.getType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank >= broadcastSrcRank)
    return Value();
  // Check that the dimensions of the result haven't been broadcasted.
  auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
  auto broadcastVecType = source.getType().dyn_cast<VectorType>();
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();
  auto extractPos = extractVector<int64_t>(extractOp.getPosition());
  unsigned rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  extractOp.setOperand(source);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(extractPos));
  return extractOp.getResult();
}
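// Illustrative fold (added sketch): extracting an inner vector of a broadcast
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[0, 1] : vector<2x3x4xf32>
// folds to %src directly, since the extracted slice has the source's type.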
/// Fold an ExtractOp whose source comes from a ShapeCastOp.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shape_cast op source.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract position from the
  // original shape and convert it to a linearized position.
  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated with the shape_cast op vector source
  // and delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.getSource());
  return extractOp.getResult();
}
/// Fold an ExtractOp from an ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();
  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extract_strided_slice op.
  if (destinationRank >
      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
    return Value();
  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(extractedPos));
  return extractOp.getResult();
}
/// Fold an ExtractOp fed from a chain of InsertStridedSliceOps.
static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
  int64_t destinationRank = op.getType().isa<VectorType>()
                                ? op.getType().cast<VectorType>().getRank()
                                : 0;
  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    auto extractOffsets = extractVector<int64_t>(op.getPosition());

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the vector inserted.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap and cannot fold.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      op.getVectorMutable().assign(insertOp.getSource());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractOp::getPositionAttrStrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return op.getResult();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep looking
    // in the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
  if (getPosition().empty())
    return getVector();
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  // ... remaining fold hooks (broadcast, shape_cast, strided slice) ...
}
// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    Operation *defOp = extractOp.getVector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    if (extractOp.getType() == source.getType())
      return failure();
    auto getRank = [](Type type) {
      return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
    };
    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    // Only consider the case where the source rank is less than or equal to
    // the extract result rank; the other cases are handled by the folders.
    if (extractResultRank < broadcastSrcRank)
      return failure();
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        extractOp, extractOp.getType(), source);
    return success();
  }
};

// Pattern to rewrite an ExtractOp(ConstantOp) -> ConstantOp.
class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Return if the ExtractOp operand is not defined by a ConstantOp.
    auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
    if (!constantOp)
      return failure();
    // ...
    if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>()) {
      // ... splat the extracted constant to the destination vector type ...
    }
    return success();
  }
};

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(attr.cast<IntegerAttr>().getInt());
}
void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
                         Value vector, ValueRange ids,
                         ArrayRef<int64_t> multiplicity,
                         AffineMap permutationMap) {
  assert(ids.size() == multiplicity.size() &&
         ids.size() == permutationMap.getNumResults());
  VectorType type = vector.getType().cast<VectorType>();
  SmallVector<int64_t, 4> newShape(type.getShape().begin(),
                                   type.getShape().end());
  for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
    auto dim = permutationMap.getResult(i).cast<AffineDimExpr>();
    newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
  }
  VectorType resultType = VectorType::get(newShape, type.getElementType());
  ExtractMapOp::build(builder, result, resultType, vector, ids);
}
LogicalResult ExtractMapOp::verify() {
  if (getSourceVectorType().getRank() != getResultType().getRank())
    return emitOpError("expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) {
    if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) !=
        0)
      return emitOpError("source vector dimensions must be a multiple of "
                         "destination vector dimensions");
    if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
      numId++;
  }
  if (numId != getIds().size())
    return emitOpError("expected number of ids must match the number of "
                       "dimensions distributed");
  return success();
}
OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
  auto insert = getVector().getDefiningOp<vector::InsertMapOp>();
  if (insert == nullptr || getType() != insert.getVector().getType() ||
      getIds() != insert.getIds())
    return {};
  return insert.getVector();
}

void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
  assert(multiplicity.empty());
  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
    if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
      multiplicity.push_back(getSourceVectorType().getDimSize(i) /
                             getResultType().getDimSize(i));
  }
}
template <typename MapOp>
AffineMap calculateImplicitMap(MapOp op) {
  SmallVector<AffineExpr, 4> perm;
  // Check which dimensions have a multiplicity greater than 1 and associate
  // them to the ids in order.
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      perm.push_back(getAffineDimExpr(i, op.getContext()));
  }
  return AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
                        op.getContext());
}
BroadcastableToResult
mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
                                std::pair<int, int> *mismatchingDims) {
  // From now on, only vectors broadcast.
  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // Source has an exact match or singleton value for all trailing dimensions
  // (all leading dimensions are simply duplicated).
  int64_t lead = dstRank - srcRank;
  for (int64_t r = 0; r < srcRank; ++r) {
    int64_t srcDim = srcVectorType.getDimSize(r);
    int64_t dstDim = dstVectorType.getDimSize(lead + r);
    if (srcDim != 1 && srcDim != dstDim) {
      if (mismatchingDims) {
        mismatchingDims->first = srcDim;
        mismatchingDims->second = dstDim;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}

LogicalResult BroadcastOp::verify() {
  std::pair<int, int> mismatchingDims;
  BroadcastableToResult res =
      isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch)
    return emitOpError("dimension mismatch (")
           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getSourceType() == getVectorType())
    return getSource();
  if (!operands[0])
    return {};
  auto vectorType = getVectorType();
  if (operands[0].isa<IntegerAttr, FloatAttr>())
    return DenseElementsAttr::get(vectorType, operands[0]);
  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  return {};
}

// Fold broadcast1(broadcast2(x)) into broadcast1(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
    return success();
  }
};

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder>(context);
}
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  if (resRank != v1Rank || v1Rank != v2Rank)
    return emitOpError("rank mismatch");
  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  auto maskAttr = getMask().getValue();
  int64_t maskLength = maskAttr.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
  for (const auto &en : llvm::enumerate(maskAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
      return emitOpError("mask index #") << (en.index() + 1) << " out of range";
  }
  return success();
}
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
                            ValueRange operands, DictionaryAttr attributes,
                            RegionRange,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  ShuffleOp::Adaptor op(operands, attributes);
  auto v1Type = op.getV1().getType().cast<VectorType>();
  // Construct the resulting type: the leading dimension matches the mask
  // length, and all trailing dimensions match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Type.getRank());
  shape.push_back(std::max<size_t>(1, op.getMask().size()));
  llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}
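// Illustrative inferred type (added sketch): for
//   %s = vector.shuffle %a, %b [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
// the mask has 4 entries, so the inferred result type is vector<4xf32>.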
static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
  uint64_t expected = begin;
  return idxArr.size() == width &&
         llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
                      [&expected](auto attr) {
                        return attr.getZExtValue() == expected++;
                      });
}

OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
  VectorType v1Type = getV1VectorType();
  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
  if (!v1Type.isScalable() &&
      isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
    return getV1();
  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
      isStepIndexArray(getMask(), v1Type.getDimSize(0),
                       getV2VectorType().getDimSize(0)))
    return getV2();

  Attribute lhs = operands.front(), rhs = operands.back();
  if (!lhs || !rhs)
    return {};

  auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (lhsType.getRank() != 1)
    return {};
  int64_t lhsSize = lhsType.getDimSize(0);

  SmallVector<Attribute> results;
  auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
  auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
    int64_t i = index.getZExtValue();
    if (i >= lhsSize)
      results.push_back(rhsElements[i - lhsSize]);
    else
      results.push_back(lhsElements[i]);
  }

  return DenseElementsAttr::get(getVectorType(), results);
}
/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();

    if (!v1Splat || !v2Splat)
      return failure();

    if (v1Splat.getInput() != v2Splat.getInput())
      return failure();

    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
    return success();
  }
};

void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat>(context);
}
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest) {
  build(builder, result, source, dest, {});
}

LogicalResult InsertElementOp::verify() {
  auto dstVectorType = getDestVectorType();
  if (dstVectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (dstVectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}
OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
  // Skip the 0-D vector here.
  if (operands.size() < 3)
    return {};

  Attribute src = operands[0];
  Attribute dst = operands[1];
  Attribute pos = operands[2];
  if (!src || !dst || !pos)
    return {};

  auto dstElements = dst.cast<DenseElementsAttr>().getValues<Attribute>();

  SmallVector<Attribute> results(dstElements);

  auto attr = pos.dyn_cast<IntegerAttr>();
  uint64_t posIdx = attr.getInt();

  results[posIdx] = src;

  return DenseElementsAttr::get(getDestVectorType(), results);
}
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ArrayRef<int64_t> position) {
  result.addOperands({source, dest});
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(dest.getType());
  result.addAttribute(getPositionAttrStrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
      }));
  build(builder, result, source, dest, positionConstants);
}
LogicalResult InsertOp::verify() {
  auto positionAttr = getPosition().getValue();
  auto destVectorType = getDestVectorType();
  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank smaller than dest vector rank");
  auto srcVectorType = getSourceType().dyn_cast<VectorType>();
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (positionAttr.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (const auto &en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= destVectorType.getDimSize(en.index()))
      return emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "dest vector dimension";
  }
  return success();
}
// If insertOp only inserts unit dimensions, it can be rewritten as a
// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getSource());
    return success();
  }
};

/// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();

    if (!srcSplat || !dstSplat)
      return failure();

    if (srcSplat.getInput() != dstSplat.getInput())
      return failure();

    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
    return success();
  }
};

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
}
// Eliminate insert operations that produce values identical to their source
// value; this happens when source and destination vectors have an identical
// size (empty position).
OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
  if (getPosition().empty())
    return getSource();
  return {};
}
LogicalResult InsertMapOp::verify() {
  if (getSourceVectorType().getRank() != getResultType().getRank())
    return emitOpError("expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) {
    if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) !=
        0)
      return emitOpError(
          "destination vector size must be a multiple of source vector size");
    if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i))
      numId++;
  }
  if (numId != getIds().size())
    return emitOpError("expected number of ids must match the number of "
                       "dimensions distributed");
  return success();
}
void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
  result.addAttribute(getStridesAttrStrName(), stridesAttr);
}
// TODO: Should be moved to Tablegen ConfinedAttr attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank smaller than vector rank";
  return success();
}
// Returns true if all integers in `arrayAttr` are in the interval [min, max).
// If `halfOpen` is true then the admissible interval is [min, max);
// otherwise it is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = attr.cast<IntegerAttr>().getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
// Returns true if all integers in `arrayAttr` are in the interval formed by
// the shape, dimension by dimension.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  assert(arrayAttr.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr, shape)) {
    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<1>(it);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
    ++index;
  }
  return success();
}
// Returns true if, for all indices, the sum of arrayAttr1[i] and
// arrayAttr2[i] lies in the interval formed by the shape, dimension by
// dimension.
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
    auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
    ++index;
  }
  return success();
}
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be smaller than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  return success();
}
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto srcSplatOp =
        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
    auto destSplatOp =
        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();

    if (!srcSplatOp || !destSplatOp)
      return failure();

    if (srcSplatOp.getInput() != destSplatOp.getInput())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getSource()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check that the slices have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
      context);
}

OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
  if (getSourceVectorType() == getDestVectorType())
    return getSource();
  return {};
}
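// Illustrative no-op fold (added sketch): inserting a vector<4x4xf32> into a
// destination of the same vector<4x4xf32> type fully overwrites it, so the
// op folds to its source operand.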
void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (!getAcc().empty()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = tLHS.dyn_cast<VectorType>();
  VectorType vRHS = tRHS.dyn_cast<VectorType>();
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");
  VectorType resType =
      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                             vLHS.getElementType())
           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());

  if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
    result.attributes.append(
        OuterProductOp::getKindAttrStrName(),
        CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
                               result.getContext()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = tRHS.dyn_cast<VectorType>(),
             vACC = getOperandVectorTypeACC(), vRES = getVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}
LogicalResult ReshapeOp::verify() {
  // Verify that rank(numInputs/outputs) + numFixedVec dims matches vec rank.
  auto inputVectorType = getInputVectorType();
  auto outputVectorType = getOutputVectorType();
  int64_t inputShapeRank = getNumInputShapeSizes();
  int64_t outputShapeRank = getNumOutputShapeSizes();
  SmallVector<int64_t, 4> fixedVectorSizes;
  getFixedVectorSizes(fixedVectorSizes);
  int64_t numFixedVectorSizes = fixedVectorSizes.size();

  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
    return emitError("invalid input shape for vector type ") << inputVectorType;

  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
    return emitError("invalid output shape for vector type ")
           << outputVectorType;

  // Verify that the 'fixedVectorSizes' match an input/output vector shape
  // suffix.
  unsigned inputVectorRank = inputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = inputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
      return emitError("fixed vector size must match input vector for dim ")
             << i;
  }

  unsigned outputVectorRank = outputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = outputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
      return emitError("fixed vector size must match output vector for dim ")
             << i;
  }

  // If all shape operands are produced by constant ops, verify that the
  // product of dimensions for input/output shapes match.
  auto isDefByConstant = [](Value operand) {
    return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
  };
  if (llvm::all_of(getInputShape(), isDefByConstant) &&
      llvm::all_of(getOutputShape(), isDefByConstant)) {
    int64_t numInputElements = 1;
    for (auto operand : getInputShape())
      numInputElements *=
          cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
    int64_t numOutputElements = 1;
    for (auto operand : getOutputShape())
      numOutputElements *=
          cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
    if (numInputElements != numOutputElements)
      return emitError("product of input and output shape sizes must match");
  }
  return success();
}
// Inference works as follows:
//   1. Add 'sizes' from the prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for the remaining dims.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType());
}

void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
  result.addAttribute(getSizesAttrStrName(), sizesAttr);
  result.addAttribute(getStridesAttrStrName(), stridesAttr);
}
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape,
                                                offName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape,
                                                sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType =
      inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  return success();
}
// When the source of an ExtractStridedSliceOp comes from a chain of
// InsertStridedSliceOps, try to use the source of an insert if the extracted
// vector is a subset of one of the inserted vectors.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract an integer out of an ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return array[idx].cast<IntegerAttr>().getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the rank of the extract is greater than the rank of the insert, we
    // are likely extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool patialoverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        // If the extract interval overlaps but is not fully included we have
        // a partial overlap.
        if (offset + size > end)
          patialoverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk is a subset of the inserted element.
    if (!disjoint && !patialoverlap) {
      op.setOperand(insertOp.getSource());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep
    // looking in the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return failure();
}

OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
  if (getVectorType() == getResult().getType())
    return getVector();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();
  return {};
}
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the slice of the vector mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    assert(sliceOffsets.size() == maskDimSizes.size());
    for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t maskDimSize = std::get<0>(it);
      int64_t sliceOffset = std::get<1>(it);
      int64_t sliceSize = std::get<2>(it);
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // If any of 'sliceMaskDimSizes' are zero, then set all to zero: the
    // masked region is a conjunction of mask dim intervals.
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    // Replace 'extractStridedSliceOp' with a ConstantMaskOp over the sliced
    // mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
    return success();
  }
};
// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
class StridedSliceConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a splat ConstantOp.
    auto constantOp =
        extractStridedSliceOp.getVector().getDefiningOp<arith::ConstantOp>();
    if (!constantOp)
      return failure();
    // ... rebuild the splat constant at the sliced type ...
    return success();
  }
};
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = op.getType().cast<VectorType>();
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the source of the broadcast are
    // the same as the destination of the extract. If so we can just use a
    // broadcast, as the original dimensions are untouched.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.getSource();
    // If the inner dimensions don't match, extract from the source of the
    // original broadcast first and then broadcast the extracted value. The
    // scalar-source case can always be broadcast directly.
    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
    if (!lowerDimMatch && !isScalarSrc) {
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};
/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto splat = op.getVector().getDefiningOp<SplatOp>();
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
    return success();
  }
};

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
              StridedSliceBroadcast, StridedSliceSplat>(context);
}
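// Illustrative canonicalization (added sketch): slicing a splat
//   %s = vector.splat %f : vector<8xf32>
//   %e = vector.extract_strided_slice %s
//        {offsets = [2], sizes = [4], strides = [1]}
//        : vector<8xf32> to vector<4xf32>
// becomes vector.splat %f : vector<4xf32>.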
/// 1. Builder that sets padding to zero and an empty mask.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr) {
  Type elemType = source.getType().cast<ShapedType>().getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 2. Builder that sets padding to zero and an empty mask (variant taking an
/// AffineMap and optional in-bounds flags).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           Optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : ArrayAttr();
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        inBoundsAttr);
}

/// 3. Builder that sets the permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, Value padding,
                           Optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      source.getType().cast<ShapedType>(), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : ArrayAttr();
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder that sets padding to zero and the permutation map to
/// 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           Optional<ArrayRef<bool>> inBounds) {
  Type elemType = source.getType().cast<ShapedType>().getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, padding, inBounds);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = expr.dyn_cast<AffineDimExpr>();
    auto zero = expr.dyn_cast<AffineConstantExpr>();
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 AffineMap permutationMap, ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!shapedType.isa<MemRefType, RankedTensorType>())
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        vectorElementType.getElementTypeBitWidth() *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize = vectorType.getElementTypeBitWidth() * minorSize;
    if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    VectorType expectedMaskType =
        vector::detail::transferMaskType(vectorType, permutationMap);
    if (maskType && expectedMaskType != maskType)
      return op->emitOpError("expects mask type consistent with permutation "
                             "map: ")
             << maskType;
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (inBounds) {
    if (permutationMap.getNumResults() !=
        static_cast<int64_t>(inBounds.size()))
      return op->emitOpError("expects the optional in_bounds attr of same rank "
                             "as permutation_map results: ")
             << AffineMapAttr::get(permutationMap)
             << " vs inBounds of size: " << inBounds.size();
    for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
      if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
          !inBounds.getValue()[i].cast<BoolAttr>().getValue())
        return op->emitOpError("requires broadcast dimensions to be in-bounds");
  }

  return success();
}
template <typename TransferOp>
static void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferOp::getOperandSegmentSizeAttr());
  if (op.permutation_map().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrStrName());
  // Elide the in_bounds attribute if all dims are out-of-bounds.
  bool elideInBounds = true;
  if (auto inBounds = op.in_bounds()) {
    for (auto attr : *inBounds) {
      if (attr.template cast<BoolAttr>().getValue()) {
        elideInBounds = false;
        break;
      }
    }
  }
  if (elideInBounds)
    elidedAttrs.push_back(op.getInBoundsAttrStrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = types[0].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = types[1].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
  Attribute mapAttr = result.attributes.get(permutationAttrName);
  if (!mapAttr) {
    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    mapAttr = AffineMapAttr::get(permMap);
    result.attributes.set(permutationAttrName, mapAttr);
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = vector::detail::transferMaskType(vectorType, map);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getI32VectorAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType, permutationMap,
                              getInBounds() ? *getInBounds() : ArrayAttr())))
    return failure();

  if (auto sourceVectorElementType =
          sourceElementType.dyn_cast<VectorType>()) {
    // Source has vector element type.
    // Check that 'sourceVectorElementType' and 'paddingType' types match.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that the padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

static LogicalResult foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
  if (!cstOp)
    return false;
  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
  return cstOp.value() + vectorSize <= sourceSize;
}
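// Editorial example, not from the original source: with a constant index 2
// into a static source dim of size 8 and a vector dim of size 4, the check is
// 2 + 4 <= 8, so the access is provably in-bounds.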
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // Currently out-of-bounds, check whether we can statically determine it
    // is in-bounds.
    auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
    assert(dimExpr && "Broadcast dims must be in-bounds");
    auto inBounds =
        isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }
  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op->setAttr(TransferOp::getInBoundsAttrStrName(),
              b.getBoolArrayAttr(newInBounds));
  return success();
}
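// Editorial sketch, not from the original source, of the fold above on IR
// (assuming %c2 is an arith.constant 2 : index):
//
//   %v = vector.transfer_read %A[%c2], %pad : memref<8xf32>, vector<4xf32>
//   ==>
//   %v = vector.transfer_read %A[%c2], %pad {in_bounds = [true]}
//        : memref<8xf32>, vector<4xf32>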
static Value foldRAW(TransferReadOp readOp) {
  if (!readOp.getShapedType().isa<RankedTensorType>())
    return {};
  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
  if (Value vec = foldRAW(*this))
    return vec;
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  if (succeeded(foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}
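// Editorial sketch, not from the original source: the RAW fold forwards the
// written vector when a tensor transfer_read exactly mirrors the defining
// transfer_write (same indices, types, and permutation map, no masks):
//
//   %t1 = vector.transfer_write %v, %t0[%c0] : vector<4xf32>, tensor<4xf32>
//   %r  = vector.transfer_read %t1[%c0], %pad : tensor<4xf32>, vector<4xf32>
//   ==> %r is replaced by %v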
void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (getShapedType().isa<MemRefType>())
    effects.emplace_back(MemoryEffects::Read::get(), getSource(),
                         SideEffects::DefaultResource::get());
}
struct FoldExtractSliceIntoTransferRead
    : public OpRewritePattern<TransferReadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    if (!xferOp.getPermutationMap().isIdentity())
      return failure();
    if (xferOp.getMask())
      return failure();
    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp)
      return failure();
    if (!extractOp.hasUnitStride())
      return failure();

    // Bail on illegal rank-reduction: the rank-reduced dims must be exactly
    // the leading dims.
    int64_t rankReduced =
        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
    int64_t vectorRank = xferOp.getVectorType().getRank();
    RankedTensorType inferredDestTensorType =
        tensor::ExtractSliceOp::inferResultType(
            extractOp.getSourceType(), extractOp.getMixedOffsets(),
            extractOp.getMixedSizes(), extractOp.getMixedStrides());
    auto actualDestTensorShape = extractOp.getType().getShape();
    if (rankReduced > 0 &&
        actualDestTensorShape.take_back(vectorRank) !=
            inferredDestTensorType.getShape().take_back(vectorRank))
      return failure();

    SmallVector<Value> newIndices;
    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
    // indices first.
    for (int64_t i = 0; i < rankReduced; ++i) {
      OpFoldResult offset = extractOp.getMixedOffsets()[i];
      newIndices.push_back(getValueOrCreateConstantIndexOp(
          rewriter, extractOp.getLoc(), offset));
    }
    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
      OpFoldResult offset =
          extractOp.getMixedOffsets()[it.index() + rankReduced];
      newIndices.push_back(rewriter.create<arith::AddIOp>(
          xferOp->getLoc(), it.value(),
          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
                                          offset)));
    }
    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
    rewriter.replaceOpWithNewOp<TransferReadOp>(
        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
        xferOp.getPadding(), ArrayRef<bool>{inBounds});
    return success();
  }
};
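// Editorial sketch, not from the original source: the pattern above folds the
// slice offsets into the read indices, e.g. (modulo canonicalization of the
// added zero indices):
//
//   %0 = tensor.extract_slice %t[%a, %b] [4, 8] [1, 1]
//        : tensor<?x?xf32> to tensor<4x8xf32>
//   %1 = vector.transfer_read %0[%c0, %c0], %pad {in_bounds = [true, true]}
//        : tensor<4x8xf32>, vector<4x8xf32>
//   ==>
//   %1 = vector.transfer_read %t[%a, %b], %pad {in_bounds = [true, true]}
//        : tensor<?x?xf32>, vector<4x8xf32>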
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.hasOutOfBoundsDim() ||
        !readOp.getShapedType().isa<RankedTensorType>())
      return failure();
    auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();

    SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
    Value vec;
    if (readOp.getIndices() == defWrite.getIndices() &&
        readOp.getMask() == defWrite.getMask()) {
      SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
      // TODO: If the writeDim is a superset of the read dims we could do an
      // extract_strided_slice.
      if (writeDims == readDims)
        vec = defWrite.getVector();
    }
    if (!vec)
      return failure();
    SmallVector<unsigned> permutation;
    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
    AffineMap map = readMap.compose(writeMap);
    if (map.getNumResults() == 0)
      return failure();
    // Calculate the permutation to apply to go from the vector stored to the
    // vector read.
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
      return failure();

    Location loc = readOp.getLoc();
    // Calculate the broadcast shape by applying the reverse permutation to
    // the final shape we want.
    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
    SmallVector<int64_t> broadcastShape(destShape.size());
    for (const auto &pos : llvm::enumerate(permutation))
      broadcastShape[pos.value()] = destShape[pos.index()];
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType());
    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
                                                     transposePerm);
    return success();
  }
};

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results
      .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
          context);
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  Type resultType = dest.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            Optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : ArrayAttr();
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            Optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      dest.getType().cast<ShapedType>(), vector.getType().cast<VectorType>());
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo, maskInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = types[0].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
  auto attr = result.attributes.get(permutationAttrName);
  if (!attr) {
    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(shapedType.isa<RankedTensorType>() &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType, permutationMap,
                              getInBounds() ? *getInBounds() : ArrayAttr())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      write.getSource().getType().dyn_cast<RankedTensorType>();
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity permutation maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getSource().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero, bail.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
    return !cstOp || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getSource());
  return success();
}

static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getSource() == write.getSource() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}

static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!write.getSource().getType().isa<RankedTensorType>())
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getSource());
  return success();
}

LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, operands, results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  return foldMemRefCast(*this);
}
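// Editorial sketch, not from the original source, of the write-after-read
// fold: writing back exactly what was just read from the same tensor yields
// the original tensor.
//
//   %v = vector.transfer_read %t[%c0], %pad : tensor<4xf32>, vector<4xf32>
//   %w = vector.transfer_write %v, %t[%c0] : vector<4xf32>, tensor<4xf32>
//   ==> all uses of %w are replaced by %t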
void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (getShapedType().isa<MemRefType>())
    effects.emplace_back(MemoryEffects::Write::get(), getSource(),
                         SideEffects::DefaultResource::get());
}

/// Remove a dead tensor store from the SSA chain (write-after-write).
struct FoldWaw final : public OpRewritePattern<TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!writeOp.getShapedType().isa<RankedTensorType>())
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;
    auto defWrite =
        writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        writeToModify.getSourceMutable().assign(defWrite.getSource());
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely
      // look at the store before it to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
struct FoldInsertSliceIntoTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();

    auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!xferOp)
      return failure();
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    if (xferOp.hasOutOfBoundsDim())
      return failure();
    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
      return failure();
    if (xferOp.getMask())
      return failure();
    // Fold only if the TransferWriteOp completely overwrites the `source`
    // with a vector, i.e. the result of the write is the data of the vector.
    if (!llvm::equal(xferOp.getVectorType().getShape(),
                     xferOp.getShapedType().getShape()))
      return failure();
    if (!xferOp.getPermutationMap().isIdentity())
      return failure();

    // Bail on illegal rank-reduction: the rank-reduced dims must be exactly
    // the leading dims.
    int64_t rankReduced =
        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
    int64_t vectorRank = xferOp.getVectorType().getRank();
    RankedTensorType inferredSourceTensorType =
        tensor::ExtractSliceOp::inferResultType(
            insertOp.getType(), insertOp.getMixedOffsets(),
            insertOp.getMixedSizes(), insertOp.getMixedStrides());
    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
    if (rankReduced > 0 &&
        actualSourceTensorShape.take_back(vectorRank) !=
            inferredSourceTensorType.getShape().take_back(vectorRank))
      return failure();

    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
    rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
                                                 insertOp.getDest(), indices,
                                                 ArrayRef<bool>{inBounds});
    return success();
  }
};
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if TransferWriteOp or tensor::ExtractSliceOp is rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if the TransferWriteOp has non-zero indices.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TranferWriteOp has non-zero offset");
    }

    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
    for (const auto &it :
         llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the tensor::ExtractSliceOp in front of the TransferWriteOp.
    SmallVector<int64_t> newResultShape = applyPermutationMap(
        transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
    SmallVector<bool> newInBounds;
    for (const auto &en : enumerate(newResultShape))
      newInBounds.push_back(en.value() == vectorShape[en.index()]);
    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
        extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
        insertOp.getMixedStrides());
    auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
        transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
        transferOp.getIndices(), transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(newInBounds));
    rewriter.updateRootInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
              SwapExtractSliceOfTransferWrite>(context);
}
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 MemRefType memRefTy) {
  if (!isLastMemrefDimUnitStride(memRefTy))
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}

LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
} // namespace

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
} // namespace

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}
LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!baseType.isa<MemRefType, RankedTensorType>())
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getDimSize(0) != indVType.getDimSize(0))
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};
} // namespace

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder>(context);
}
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder>(context);
}
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
/// Returns true if each element of 'a' is equal to the product of a
/// contiguous sequence of the elements of 'b'. Returns false otherwise.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    auto isOne = [](int64_t v) { return v == 1; };
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  return i == rankA && j == rankB;
}
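// Editorial example, not from the original source: a = [6, 4], b = [2, 3, 4]
// is a valid cast, since 6 = 2 * 3 (a contiguous prefix of b) and 4 = 4. By
// contrast, a = [4, 6], b = [2, 3, 4] is rejected: the running product of b's
// leading dims reaches 2, then 6, without ever equaling 4.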
static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {
  // Check that element type is the same.
  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
    return op->emitOpError("source/result vectors must have same element type");
  auto sourceShape = sourceVectorType.getShape();
  auto resultShape = resultVectorType.getShape();

  // Check that product of source dim sizes matches product of result dim
  // sizes.
  int64_t sourceDimProduct = std::accumulate(
      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
  int64_t resultDimProduct = std::accumulate(
      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
  if (sourceDimProduct != resultDimProduct)
    return op->emitOpError("source/result number of elements must match");

  // Check that the rank-expanding/contracting reshape is valid.
  unsigned sourceRank = sourceVectorType.getRank();
  unsigned resultRank = resultVectorType.getRank();
  if (sourceRank < resultRank) {
    if (!isValidShapeCast(sourceShape, resultShape))
      return op->emitOpError("invalid shape cast");
  } else if (sourceRank > resultRank) {
    if (!isValidShapeCast(resultShape, sourceShape))
      return op->emitOpError("invalid shape cast");
  }
  return success();
}
LogicalResult ShapeCastOp::verify() {
  auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>();
  auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>();

  // Check if source/result are of vector type.
  if (sourceVectorType && resultVectorType)
    return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);

  return success();
}

OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
  // No-op shape cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling shape casts.
  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Only allow valid transitive folding.
    VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
    VectorType resultType = getResult().getType().cast<VectorType>();
    if (srcType.getRank() < resultType.getRank()) {
      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
        return {};
    } else if (srcType.getRank() > resultType.getRank()) {
      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
        return {};
    } else {
      return {};
    }

    setOperand(otherOp.getSource());
    return getResult();
  }

  // Canceling broadcast and shape cast ops.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == getType())
      return bcastOp.getSource();
  }

  return {};
}
namespace {
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto constantOp =
        shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
    if (!constantOp)
      return failure();
    // ... only splat constants are handled: rebuild the attribute with the
    // result vector type and replace the shape cast with a new constant ...
  }
};

// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast. This only applies
// when the shape of the broadcast source is a suffix of the result shape.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto broadcastSourceVectorType =
        broadcastOp.getSourceType().dyn_cast<VectorType>();
    auto broadcastSourceShape = broadcastSourceVectorType
                                    ? broadcastSourceVectorType.getShape()
                                    : ArrayRef<int64_t>{};
    auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();

    // Bail if the broadcast source shape is not a suffix of the target shape.
    bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
                                                 broadcastSourceShape.size()));
    if (!isSuffix)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        shapeCastOp, shapeCastOp.getResultVectorType(),
        broadcastOp.getSource());
    return success();
  }
};
} // namespace

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
}
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
  // No-op cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = operands.front();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern into both halves of the 32-bit word.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  return {};
}
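// Editorial example, not from the original source: bitcasting a splat
// vector<4xf16> of 1.0 (bit pattern 0x3C00) to vector<2xf32> packs two copies
// of the 16-bit pattern into each 32-bit word:
// (0x3C00 << 16) | 0x3C00 == 0x3C003C00, which is then reinterpreted as an
// f32 splat.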
/// Returns the shape of the memref concatenated with the shape of its vector
/// element type, if any.
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
                              memRefType.getShape().end());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = source.getType().cast<MemRefType>();
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}

LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> transp) {
  VectorType vt = vector.getType().cast<VectorType>();
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  for (unsigned i = 0; i < transp.size(); ++i)
    transposedShape[i] = vt.getShape()[transp[i]];

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
  result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
}
OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
  // Eliminate splat constant transpose ops.
  if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
    if (attr.isSplat())
      return attr.reshape(getResultType());

  // Eliminate identity transpose ops. This happens when the dimensions of the
  // input vector remain in their original order after the transpose.
  SmallVector<int64_t, 4> transp;
  getTransp(transp);

  // Check if the permutation of the dimensions contains sequential values:
  // {0, 1, 2, ...}.
  for (int64_t i = 0, e = transp.size(); i < e; i++) {
    if (transp[i] != i)
      return {};
  }

  return getVector();
}

LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getVectorType();
  VectorType resultType = getResultType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  auto transpAttr = getTransp().getValue();
  int64_t size = transpAttr.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(transpAttr)) {
    int64_t i = ta.value().cast<IntegerAttr>().getInt();
    if (i < 0 || i >= rank)
      return emitOpError("transposition index out of range: ") << i;
    if (seen[i])
      return emitOpError("duplicate position index: ") << i;
    seen[i] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }
  return success();
}

Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultType().getShape());
}
namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Wrapper around vector::TransposeOp::getTransp() data type.
    auto getPermutation = [](vector::TransposeOp transpose) {
      SmallVector<int64_t, 4> permutation;
      transpose.getTransp(permutation);
      return permutation;
    };
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };
    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        getPermutation(parentTransposeOp), getPermutation(transposeOp));
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(),
        vector::getVectorSubscriptAttr(rewriter, permutation));
    return success();
  }
};

// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();
    auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          transposeOp, transposeOp.getResultType(), bcastOp.getSource());
      return success();
    }
    return failure();
  }
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
    if (!splatOp)
      return failure();
    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        transposeOp, transposeOp.getResultType(), splatOp.getInput());
    return success();
  }
};
} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results
      .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
          context);
}
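// Editorial example, not from the original source: composing as
// result[i] = permutation1[permutation2[i]] means that
//
//   %a = vector.transpose %x, [1, 0, 2] ...
//   %b = vector.transpose %a, [2, 0, 1] ...
//
// folds to vector.transpose %x, [2, 1, 0], since result[0] = p1[p2[0]] =
// p1[2] = 2, result[1] = p1[0] = 1, and result[2] = p1[1] = 0.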
LogicalResult ConstantMaskOp::verify() {
  auto resultType = getResult().getType().cast<VectorType>();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt();
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of corresponding vector
  // result dimension size.
  auto resultShape = resultType.getShape();
  SmallVector<int64_t, 4> maskDimSizes;
  for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
    int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
    if (attrValue < 0 || attrValue > resultShape[it.index()])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    maskDimSizes.push_back(attrValue);
  }
  // Verify that if one mask dim size is zero, they all should be zero (because
  // the mask region is a conjunction of each mask dimension interval).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  // Verify that if the mask type is scalable, dimensions should be zero:
  // constant scalable masks can only be defined for the "none set" case.
  if (resultType.isScalable() &&
      getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0)
    return emitOpError("expected mask dim sizes for scalable masks to be 0");
  return success();
}
LogicalResult CreateMaskOp::verify() {
  auto vectorType = getResult().getType().cast<VectorType>();
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             getResult().getType().cast<VectorType>().getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
namespace {
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    // Return if any of 'createMaskOp' operands are not defined by a constant.
    auto isNotDefByConstant = [](Value operand) {
      return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
      return failure();

    // CreateMaskOp for scalable vectors can be folded only if all dimensions
    // are negative or zero.
    if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
      if (vType.isScalable())
        for (auto opDim : createMaskOp.getOperands()) {
          APInt intVal;
          if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
              intVal.isStrictlyPositive())
            return failure();
        }
    }

    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    for (auto it : llvm::zip(createMaskOp.operands(),
                             createMaskOp.getType().getShape())) {
      auto *defOp = std::get<0>(it).getDefiningOp();
      int64_t maxDimSize = std::get<1>(it);
      int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
      dimSize = std::min(dimSize, maxDimSize);
      // If one of the dim sizes is zero, set all dims to zero.
      if (dimSize <= 0) {
        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
        break;
      }
      maskDimSizes.push_back(dimSize);
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
    return success();
  }
};
} // namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
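// Editorial sketch, not from the original source, of the canonicalization:
//
//   %c2 = arith.constant 2 : index
//   %c3 = arith.constant 3 : index
//   %m = vector.create_mask %c2, %c3 : vector<4x5xi1>
//   ==>
//   %m = vector.constant_mask [2, 3] : vector<4x5xi1>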
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != std::get<1>(s);
                   })) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext());
}
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = operands.front();
  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}

ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // ... parse the `(laneid)[warp_size]` header ...
  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  // ... parse the optional `args(...)` list, result types, region, and
  // attribute dict ...
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();
  // Ensure that the region is well formed: add a terminator if needed.
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
  return success();
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }
  // The warp region is always executed.
  regions.push_back(RegionSuccessor(&getWarpRegion()));
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/llvm::None, /*argTypes=*/llvm::None);
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  // ... add laneId, the warp_size attribute, result types and args ...
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto it : llvm::zip(blockArgTypes, args))
    block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
}
/// Checks if the distributed vector type is consistent with the expanded type
/// and the distribution size (warpSize).
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = expanded.dyn_cast<VectorType>();
  auto distributedVecType = distributed.dyn_cast<VectorType>();
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");
  bool foundDistributedDim = false;
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
      continue;
    if (expandedVecType.getDimSize(i) ==
        distributedVecType.getDimSize(i) * warpSize) {
      if (foundDistributedDim)
        return op->emitOpError()
               << "expected only one dimension to be distributed from "
               << expandedVecType << " to " << distributedVecType;
      foundDistributedDim = true;
      continue;
    }
    return op->emitOpError() << "incompatible distribution dimensions from "
                             << expandedVecType << " to "
                             << distributedVecType;
  }
  return success();
}
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  auto yield =
      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(std::get<0>(it).getType(),
                                     std::get<1>(it).getType(), warpSize,
                                     getOperation())))
      return failure();
  }
  for (auto it : llvm::zip(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(std::get<0>(it).getType(),
                                     std::get<1>(it).getType(), warpSize,
                                     getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1,
                                       Value v2) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type t2 = getElementTypeOrSelf(v2.getType());
  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && t2.isIntOrIndex())
      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
    llvm_unreachable("invalid value types for ADD reduction");
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
  case CombiningKind::MAXF:
    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
           "expected float values");
    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
  case CombiningKind::MINF:
    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
           "expected float values");
    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && t2.isIntOrIndex())
      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
    llvm_unreachable("invalid value types for MUL reduction");
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
  }
  llvm_unreachable("unknown CombiningKind");
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
This parses a single MLIR attribute to an MLIR context if it was valid.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
virtual ParseResult parseLParen()=0
Parse a ( token.
An attribute that represents a reference to a dense float vector or tensor object.
MLIRContext * getContext() const
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
U dyn_cast_or_null() const
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Operation is a basic unit of execution within MLIR.
unsigned getNumSymbols() const
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Attribute getZeroAttr(Type type)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Block represents an ordered list of Operations.
CombiningKind getKind() const
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a single result from folding an operation.
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Value getOperand(unsigned idx)
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
This is a utility allocator used to allocate memory for instances of derived types.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, StringRef targetIteratorTypeName, MLIRContext *context)
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
SmallVector< int64_t, 4 > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper that returns a subset of arrayAttr as a vector of int64_t.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d)
Helper method to apply dimension ordering permutation.
bool isLastMemrefDimUnitStride(MemRefType type)
Return true if the last dimension of the MemRefType has unit stride.
This is the representation of an operand reference.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
constexpr StringRef getIteratorTypesAttrName()
Attribute name for the StrArrayAttr which encodes the type of a structured op's iterators.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
virtual ParseResult parseCustomAttributeWithFallback(Attribute &result, Type type, function_ref< ParseResult(Attribute &result, Type type)> parseAttribute)=0
Parse a custom attribute with the provided callback, unless the next token is #, in which case the ge...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
An integer constant appearing in affine expression.
static ArrayRef< int64_t > vectorShape(Type type)
virtual ParseResult parseComma()=0
Parse a , token.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
static constexpr const bool value
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
AffineExpr getResult(unsigned idx) const
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
unsigned getNumInputs() const
static DefaultResource * get()
Returns a unique instance for the given effect class.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
MutableArrayRef< OpOperand > getOpOperands()
virtual ParseResult parseLSquare()=0
Parse a [ token.
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector...
T * allocate()
Allocate an instance of the provided type.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents an efficient way to signal success or failure.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
An attribute that represents a reference to a dense vector or tensor object.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseGreater()=0
Parse a '>' token.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
void addOperands(ValueRange newOperands)
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
static LogicalResult foldMemRefCast(Operation *op)
This is a common class used for patterns of the form someop(memrefcast) -> someop It folds the source...
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
IntegerAttr getI64IntegerAttr(int64_t value)
BitmaskEnumStorage(KeyTy val)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, unsigned memorySpace=0)
Return a MemRefType to which the type of the given value can be bufferized.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Operation::operand_range getIndices(Operation *op)
IntegerType getIntegerType(unsigned width)
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
virtual ParseResult parseRParen()=0
Parse a ) token.
Base type for affine expression.
SmallVector< int64_t, 4 > delinearize(ArrayRef< int64_t > strides, int64_t linearIndex)
Given the strides together with a linear index in the dimension space, returns the vector-space offse...
static LogicalResult foldTensorCast(Operation *op)
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, AffineMap permutationMap, ArrayAttr inBounds)
This class provides an abstraction over the various different ranges of value types.
void addTypes(ArrayRef< Type > newTypes)
virtual ParseResult parseLess()=0
Parse a '<' token.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width)
void print(AsmPrinter &p) const
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
unsigned getNumResults() const
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
MaskFormat
Helper enum to classify mask value.
This represents an operation in an abstracted form, suitable for use with the builder APIs...
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns)
Collect a set of vector-to-vector canonicalization patterns.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued...
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
ArrayRef< AffineExpr > getResults() const
This class represents a specific instance of an effect.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
virtual ParseResult parseRSquare()=0
Parse a ] token.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
LogicalResult emitOptionalError(Optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
Location getLoc() const
Return the location of this value.
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This base class exposes generic asm parser hooks, usable across the various derived parsers...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
unsigned getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation...
An attribute that specifies the combining function for vector.contract, and vector.reduction.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
MLIRContext * getContext() const
Get the context held by this operation state.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This is a builder type that keeps local references to arguments.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
constexpr StringRef getIndexingMapsAttrName()
Attribute name for the AffineArrayAttr which encodes the relationship between a structured op iterato...
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Type getType() const
Return the type of this value.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static constexpr const CombiningKind combiningKindsList[]
Do not split vector transfer operations.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
A dimensional identifier appearing in an affine expression.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Specialization of arith.constant op that returns an integer of index type.
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
constexpr StringRef getStridesAttrName()
Attribute name for the StrArrayAttr which encodes the value of strides.
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
virtual ParseResult parseType(Type &result)=0
Parse a type.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
static Value foldExtractStridedOpFromInsertChain(ExtractOp op)
Fold extract_op fed from a chain of insertStridedSlice ops.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
AffineMap calculateImplicitMap(MapOp op)
MLIRContext is the top-level object for a collection of MLIR operations.
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
This class represents an operand of an operation.
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
This class provides an abstraction over the different types of ranges over Regions.
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the canonical contiguous strided layout.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Same behavior as isDisjointTransferSet but doesn't require the operations to have the same tensor/memref operand.
Base storage class appearing in an attribute.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure.
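A hedged fragment of the typical call site inside a pattern's matchAndRewrite; the op type and the failure condition are placeholders:

LogicalResult matchAndRewrite(vector::TransferReadOp op,
                              PatternRewriter &rewriter) const override {
  // Record a human-readable reason that -debug output can surface.
  if (op.getMask())
    return rewriter.notifyMatchFailure(op, "masked transfers not handled");
  // ... actual rewrite ...
  return success();
}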
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
static CombiningKindAttr get(CombiningKind kind, MLIRContext *context)
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
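For example, a made-up verifier check; the op, type, and condition are illustrative only:

// Printed as: 'vector.some_op' op expected a rank-1 vector type
if (resultType.getRank() != 1)
  return op.emitOpError("expected a rank-1 vector type");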
Builder & setElementType(Type newElementType)
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
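A sketch of the usual pattern; b, loc, and ofr are assumed to be in scope:

// If ofr wraps an IntegerAttr, an arith.constant of index type is created;
// if it already wraps a Value, that Value is returned as-is.
Value idx = getValueOrCreateConstantIndexOp(b, loc, ofr);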
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
constexpr StringRef getReductionIteratorTypeName()
Use to encode that a particular iterator type has reduction semantics.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs, and as many symbols as the largest symbol in exprs.
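A hedged sketch building a transposition map this way; ctx is an assumed MLIRContext*:

AffineExpr d0, d1;
bindDims(ctx, d0, d1);
// One result list {d1, d0} yields one map: (d0, d1) -> (d1, d0),
// with 2 dims and 0 symbols inferred from the expressions themselves.
SmallVector<SmallVector<AffineExpr, 4>, 1> exprs = {{d1, d0}};
auto maps = AffineMap::inferFromExprList(exprs);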
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class represents success/failure for parsing-like operations that find it important to chain together failable operations with ||.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
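A short sketch of branching on the result; ofr is an assumed OpFoldResult:

if (Optional<int64_t> cst = getConstantIntValue(ofr)) {
  // Statically known: use *cst directly.
} else {
  // Dynamic: fall back to the SSA value, e.g. ofr.get<Value>().
}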
static Attribute parse(AsmParser &parser, Type type)
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
Return a fused vector::ContractionOp which represents a pattern such as: ...
VectorType transferMaskType(VectorType vecType, AffineMap map)
Given the vector type and the permutation map of a vector transfer op, compute the expected mask type.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter.
bool operator==(const KeyTy &key) const
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Square brackets surrounding zero or more operands.
An attribute that represents a reference to a dense integer vector or tensor object.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
The main mechanism for performing data layout queries.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.
static MaskFormat get1DMaskFormat(Value mask)
Helper method to classify a 1-D mask value.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Return true if we can prove that the transfer operations access disjoint memory.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
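A plain-C++ sketch of the computation, assuming basis holds the row-major strides (suffix products of the shape) rather than the shape itself:

// shape [4, 8, 16] => basis (strides) [128, 16, 1];
// offsets [1, 2, 3] => 1*128 + 2*16 + 3*1 = 163.
int64_t linearIndex = 0;
for (size_t i = 0, e = offsets.size(); i < e; ++i)
  linearIndex += offsets[i] * basis[i];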
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)