39 #include "llvm/ADT/ArrayRef.h"
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/StringSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/bit.h"
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
73 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75 for (
bool b : denseElts.getValues<
bool>())
78 else if (!b && val <= 0)
91 ArrayAttr masks = m.getMaskDimSizes();
92 auto shape = m.getType().getShape();
95 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96 int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
110 auto maskOperands = m.getOperands();
111 for (
Value operand : maskOperands) {
112 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
114 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
127 builder.
create<vector::YieldOp>(loc);
133 switch (combiningKind) {
134 case CombiningKind::ADD:
135 case CombiningKind::MUL:
138 case CombiningKind::MINSI:
139 case CombiningKind::MAXUI:
140 case CombiningKind::MAXSI:
141 case CombiningKind::AND:
142 case CombiningKind::OR:
143 case CombiningKind::XOR:
145 case CombiningKind::MINNUMF:
146 case CombiningKind::MAXNUMF:
147 case CombiningKind::MINIMUMF:
148 case CombiningKind::MAXIMUMF:
149 return llvm::isa<FloatType>(elementType);
155 VectorType vectorType) {
156 int64_t elementVectorRank = 0;
157 VectorType elementVectorType =
158 llvm::dyn_cast<VectorType>(shapedType.getElementType());
159 if (elementVectorType)
160 elementVectorRank += elementVectorType.getRank();
163 if (shapedType.getRank() == 0 &&
169 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
170 shapedType.getContext());
177 vector::TransferReadOp read) {
178 auto readMask = read.getMask();
179 auto writeMask = write.getMask();
185 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
186 if (!couldBeSameSplat)
191 m_Constant<DenseElementsAttr>(&splatAttr)) ||
203 vector::TransferReadOp read) {
204 return !defWrite.hasOutOfBoundsDim() &&
205 defWrite.getIndices() == read.getIndices() &&
206 defWrite.getVectorType() == read.getVectorType() &&
207 defWrite.getPermutationMap() == read.getPermutationMap() &&
208 ((!defWrite.getMask() && !read.getMask()) ||
213 vector::TransferWriteOp priorWrite) {
214 return priorWrite.getIndices() == write.getIndices() &&
215 priorWrite.getMask() == write.getMask() &&
216 priorWrite.getVectorType() == write.getVectorType() &&
217 priorWrite.getPermutationMap() == write.getPermutationMap();
221 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
222 bool testDynamicValueUsingBounds) {
224 if (transferA.getVectorType() != transferB.getVectorType())
226 unsigned rankOffset = transferA.getLeadingShapedRank();
227 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
228 Value indexA = transferA.getIndices()[i];
229 Value indexB = transferB.getIndices()[i];
233 if (i < rankOffset) {
236 if (cstIndexA.has_value() && cstIndexB.has_value()) {
237 if (*cstIndexA != *cstIndexB)
241 if (testDynamicValueUsingBounds) {
244 FailureOr<uint64_t> delta =
246 if (succeeded(delta) && *delta != 0)
249 FailureOr<bool> testEqual =
251 if (succeeded(testEqual) && !testEqual.value())
257 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
258 if (cstIndexA.has_value() && cstIndexB.has_value()) {
259 int64_t distance =
std::abs(*cstIndexA - *cstIndexB);
260 if (distance >= vectorDim)
264 if (testDynamicValueUsingBounds) {
267 FailureOr<int64_t> delta =
269 if (succeeded(delta) &&
std::abs(*delta) >= vectorDim)
272 FailureOr<int64_t> computeDelta =
274 if (succeeded(computeDelta)) {
275 if (
std::abs(computeDelta.value()) >= vectorDim)
285 VectorTransferOpInterface transferB,
286 bool testDynamicValueUsingBounds) {
287 if (transferA.getSource() != transferB.getSource())
290 testDynamicValueUsingBounds);
300 for (
auto [posInDim, dimSize, offsetInDim] :
301 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
303 if (posInDim < dimSize + offsetInDim)
307 posInDim = offsetInDim;
317 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
319 assert(constOp &&
"Unexpected non-constant index");
320 return constOp.value();
330 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
331 assert(foldResult.is<
Attribute>() &&
"Unexpected non-constant index");
332 return cast<IntegerAttr>(foldResult.get<
Attribute>()).getInt();
342 llvm::transform(foldResults, std::back_inserter(values),
344 if (
auto attr = foldResult.dyn_cast<
Attribute>())
347 loc, cast<IntegerAttr>(attr).getInt())
350 return foldResult.get<
Value>();
398 void VectorDialect::initialize() {
400 #define GET_ATTRDEF_LIST
401 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
406 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
409 addInterfaces<VectorInlinerInterface>();
411 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
412 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
414 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
416 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
417 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
425 return arith::ConstantOp::materialize(builder, value, type, loc);
441 void vector::MultiDimReductionOp::build(
OpBuilder &builder,
444 CombiningKind kind) {
448 reductionDims.push_back(en.index());
449 build(builder, result, kind, source, acc,
453 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
455 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
460 std::optional<SmallVector<int64_t, 4>>
461 MultiDimReductionOp::getShapeForUnroll() {
462 return llvm::to_vector<4>(getSourceVectorType().
getShape());
468 Type inferredReturnType;
469 auto sourceScalableDims = getSourceVectorType().getScalableDims();
471 if (!llvm::any_of(getReductionDims().getValue(), [&](
Attribute attr) {
472 return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
474 targetShape.push_back(it.value());
475 scalableDims.push_back(sourceScalableDims[it.index()]);
478 if (targetShape.empty())
479 inferredReturnType = getSourceVectorType().getElementType();
482 targetShape, getSourceVectorType().
getElementType(), scalableDims);
483 if (
getType() != inferredReturnType)
484 return emitOpError() <<
"destination type " <<
getType()
485 <<
" is incompatible with source type "
486 << getSourceVectorType();
492 Type MultiDimReductionOp::getExpectedMaskType() {
493 auto vecType = getSourceVectorType();
496 vecType.getScalableDims());
505 struct ElideUnitDimsInMultiDimReduction
509 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
512 for (
const auto &dim :
enumerate(shape)) {
513 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
521 if (reductionOp.isMasked()) {
523 rootOp = reductionOp.getMaskingOp();
524 mask = reductionOp.getMaskingOp().getMask();
526 rootOp = reductionOp;
529 Location loc = reductionOp.getLoc();
530 Value acc = reductionOp.getAcc();
532 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
534 VectorType newMaskType =
536 dstVecType.getScalableDims());
537 mask = rewriter.
create<vector::ShapeCastOp>(loc, newMaskType, mask);
539 cast = rewriter.
create<vector::ShapeCastOp>(
540 loc, reductionOp.getDestType(), reductionOp.getSource());
546 mask = rewriter.
create<vector::ExtractOp>(loc, mask, zeroIdx);
547 cast = rewriter.
create<vector::ExtractOp>(loc, reductionOp.getSource(),
553 cast,
nullptr, mask);
560 void MultiDimReductionOp::getCanonicalizationPatterns(
562 results.
add<ElideUnitDimsInMultiDimReduction>(context);
570 CombiningKind kind,
Value vector,
571 arith::FastMathFlags fastMathFlags) {
572 build(builder, result, kind, vector,
Value(), fastMathFlags);
577 arith::FastMathFlags fastMathFlags) {
578 build(builder, result,
579 llvm::cast<VectorType>(vector.
getType()).getElementType(), kind, vector,
585 int64_t rank = getSourceVectorType().getRank();
587 return emitOpError(
"unsupported reduction rank: ") << rank;
590 Type eltType = getDest().getType();
592 return emitOpError(
"unsupported reduction type '")
593 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
602 Type ReductionOp::getExpectedMaskType() {
603 auto vecType = getSourceVectorType();
606 vecType.getScalableDims());
613 case arith::AtomicRMWKind::addf:
614 case arith::AtomicRMWKind::addi:
615 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
616 CombiningKind::ADD, vector);
617 case arith::AtomicRMWKind::mulf:
618 case arith::AtomicRMWKind::muli:
619 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
620 CombiningKind::MUL, vector);
621 case arith::AtomicRMWKind::minimumf:
622 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
623 CombiningKind::MINIMUMF, vector);
624 case arith::AtomicRMWKind::mins:
625 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
626 CombiningKind::MINSI, vector);
627 case arith::AtomicRMWKind::minu:
628 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
630 case arith::AtomicRMWKind::maximumf:
631 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
632 CombiningKind::MAXIMUMF, vector);
633 case arith::AtomicRMWKind::maxs:
634 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
635 CombiningKind::MAXSI, vector);
636 case arith::AtomicRMWKind::maxu:
637 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
638 CombiningKind::MAXUI, vector);
639 case arith::AtomicRMWKind::andi:
640 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
641 CombiningKind::AND, vector);
642 case arith::AtomicRMWKind::ori:
643 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
644 CombiningKind::OR, vector);
653 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
654 return llvm::to_vector<4>(getSourceVectorType().
getShape());
661 LogicalResult matchAndRewrite(ReductionOp reductionOp,
666 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
669 if (maskableOp.isMasked()) {
671 rootOp = maskableOp.getMaskingOp();
672 mask = maskableOp.getMaskingOp().getMask();
674 rootOp = reductionOp;
677 auto vectorType = reductionOp.getSourceVectorType();
678 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
681 Location loc = reductionOp.getLoc();
683 if (vectorType.getRank() == 0) {
685 mask = rewriter.
create<ExtractElementOp>(loc, mask);
686 result = rewriter.
create<ExtractElementOp>(loc, reductionOp.getVector());
689 mask = rewriter.
create<ExtractOp>(loc, mask, 0);
690 result = rewriter.
create<ExtractOp>(loc, reductionOp.getVector(), 0);
693 if (
Value acc = reductionOp.getAcc())
696 reductionOp.getFastmathAttr(), mask);
706 results.
add<ElideSingleElementReduction>(context);
720 getIndexingMapsAttrName(result.
name),
724 getIteratorTypesAttrName(result.
name),
727 return IteratorTypeAttr::get(builder.getContext(), t);
733 ArrayAttr indexingMaps,
734 ArrayAttr iteratorTypes) {
735 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
736 ContractionOp::getDefaultKind());
741 ArrayAttr indexingMaps,
742 ArrayAttr iteratorTypes, CombiningKind kind) {
759 DictionaryAttr dictAttr;
774 dictAttr.getValue().end());
780 ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
785 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
786 auto maybeIteratorType = symbolizeIteratorType(s);
787 if (!maybeIteratorType.has_value())
788 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
790 iteratorTypeAttrs.push_back(
798 getKindAttrName(result.
name),
800 ContractionOp::getDefaultKind()));
802 if (masksInfo.empty())
804 if (masksInfo.size() != 2)
806 "expected zero or exactly 2 vector mask operands");
807 auto lhsType = llvm::cast<VectorType>(types[0]);
808 auto rhsType = llvm::cast<VectorType>(types[1]);
810 std::array<VectorType, 2> maskTypes = {
820 auto attrNames = getTraitAttrNames();
822 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
824 for (
auto attr : (*this)->getAttrs()) {
825 if (attr.getName() == getIteratorTypesAttrName()) {
827 llvm::cast<ArrayAttr>(attr.getValue())
828 .getAsValueRange<IteratorTypeAttr, IteratorType>();
834 llvm::map_range(iteratorTypes, [&](IteratorType t) ->
Attribute {
838 attrs.emplace_back(getIteratorTypesAttrName(),
840 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
841 attrs.push_back(attr);
845 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
846 p << getRhs() <<
", " << getAcc();
849 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
854 const std::vector<std::pair<int64_t, int64_t>> &map) {
855 for (
auto &dimPair : map) {
856 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
857 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
858 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
865 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
867 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
868 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
871 for (
auto &dimPair : contractingDimMap) {
872 lhsContractingDimSet.insert(dimPair.first);
873 rhsContractingDimSet.insert(dimPair.second);
876 for (
auto &dimPair : batchDimMap)
877 rhsBatchDimSet.insert(dimPair.second);
881 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
882 if (lhsContractingDimSet.count(i) > 0)
884 expectedResultDims.push_back(lhsType.getDimSize(i));
888 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
889 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
891 expectedResultDims.push_back(rhsType.getDimSize(i));
895 if (expectedResultDims.empty()) {
897 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
898 return op.
emitOpError(
"invalid accumulator/result vector shape");
901 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
902 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
903 if (!resVectorType || !accVectorType)
904 return op.
emitOpError(
"invalid accumulator/result vector shape");
910 AffineMap lhsMap = op.getIndexingMapsArray()[0];
911 AffineMap rhsMap = op.getIndexingMapsArray()[1];
914 "expected all dimensions to be either a LHS or a RHS dimension");
917 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
918 VectorType v = pair.first;
919 auto map = pair.second;
920 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
921 unsigned pos = map.getDimPosition(idx);
926 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
927 return op.
emitOpError(
"expected all dimensions to get an extent as "
928 "either a LHS or a RHS dimension");
930 AffineMap resMap = op.getIndexingMapsArray()[2];
936 llvm::IsaPred<AffineConstantExpr>) &&
937 "expected constant extent along all dimensions.");
939 auto expectedShape = llvm::to_vector<4>(
941 return cast<AffineConstantExpr>(e).getValue();
945 resVectorType.getScalableDims());
946 if (resVectorType != expected || accVectorType != expected)
948 "invalid accumulator/result vector shape, expected: ")
955 VectorType lhsType = getLhsType();
956 VectorType rhsType = getRhsType();
957 Type accType = getAccType();
958 Type resType = getResultType();
960 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
961 if (!lhsType.getElementType().isSignlessInteger())
962 return emitOpError(
"only supports signless integer types");
966 if (getIndexingMapsArray().size() != 3)
967 return emitOpError(
"expected an indexing map for each vector operand");
972 unsigned numIterators = getIteratorTypes().getValue().size();
974 auto index = it.index();
975 auto map = it.value();
976 if (map.getNumSymbols() != 0)
977 return emitOpError(
"expected indexing map ")
978 << index <<
" to have no symbols";
979 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).
getType());
980 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
983 if (map.getNumDims() != numIterators)
984 return emitOpError(
"expected indexing map ")
985 << index <<
" to have " << numIterators <<
" number of inputs";
986 if (map.getNumResults() != rank)
987 return emitOpError(
"expected indexing map ")
988 << index <<
" to have " << rank <<
" number of outputs";
989 if (!map.isProjectedPermutation())
990 return emitOpError(
"expected indexing map ")
991 << index <<
" to be a projected permutation of its inputs";
994 auto contractingDimMap = getContractingDimMap();
995 auto batchDimMap = getBatchDimMap();
998 if (contractingDimMap.empty())
999 return emitOpError(
"expected at least one contracting dimension pair");
1002 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1003 return emitOpError(
"invalid contracting dimension map");
1007 return emitOpError(
"invalid batch dimension map");
1011 contractingDimMap, batchDimMap)))
1015 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1016 auto elementType = vectorType ? vectorType.getElementType() : resType;
1018 return emitOpError(
"unsupported contraction type");
1027 Type ContractionOp::getExpectedMaskType() {
1028 auto indexingMaps = this->getIndexingMapsArray();
1031 VectorType lhsType = this->getLhsType();
1032 VectorType rhsType = this->getRhsType();
1034 unsigned numVecDims = lhsIdxMap.
getNumDims();
1043 lhsType.getScalableDims()[dimIdx];
1048 rhsType.getScalableDims()[dimIdx];
1051 assert(!ShapedType::isDynamicShape(maskShape) &&
1052 "Mask shape couldn't be computed");
1056 maskShapeScalableDims);
1061 getIteratorTypesAttrName(), getKindAttrName()};
1071 static std::vector<std::pair<int64_t, int64_t>>
1073 IteratorType targetIteratorType,
MLIRContext *context) {
1074 std::vector<std::pair<int64_t, int64_t>> dimMap;
1076 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1077 if (iteratorType != targetIteratorType)
1083 if (lhsDim >= 0 && rhsDim >= 0)
1084 dimMap.emplace_back(lhsDim, rhsDim);
1089 void ContractionOp::getIterationBounds(
1091 auto lhsShape = getLhsType().getShape();
1092 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1098 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1099 if (iteratorType == IteratorType::reduction) {
1101 int64_t lhsDimIndex =
getResultIndex(indexingMaps[0], targetExpr);
1102 assert(lhsDimIndex >= 0);
1103 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1107 int64_t resDimIndex =
getResultIndex(indexingMaps[2], targetExpr);
1108 assert(resDimIndex >= 0);
1109 assert(resVectorType !=
nullptr);
1110 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1114 void ContractionOp::getIterationIndexMap(
1116 unsigned numMaps = getIndexingMapsArray().size();
1117 iterationIndexMap.resize(numMaps);
1119 auto index = it.index();
1120 auto map = it.value();
1121 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1122 auto dim = cast<AffineDimExpr>(map.getResult(i));
1123 iterationIndexMap[index][dim.getPosition()] = i;
1128 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1130 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1134 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1136 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1140 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1142 getIterationBounds(shape);
1164 template <
typename AddOpType>
1170 auto canonicalize = [&](
Value maybeContraction,
1171 Value otherOperand) -> vector::ContractionOp {
1172 vector::ContractionOp contractionOp =
1173 dyn_cast_or_null<vector::ContractionOp>(
1176 return vector::ContractionOp();
1177 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1178 contractionOp.getAcc().getDefiningOp())) {
1179 if (maybeZero.getValue() ==
1180 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1182 bvm.
map(contractionOp.getAcc(), otherOperand);
1183 auto newContraction =
1184 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1185 rewriter.
replaceOp(addOp, newContraction.getResult());
1186 return newContraction;
1189 return vector::ContractionOp();
1192 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1193 vector::ContractionOp
contract = canonicalize(a, b);
1195 return contract ? success() : failure();
1212 result.
addTypes(llvm::cast<VectorType>(source.
getType()).getElementType());
1216 VectorType vectorType = getSourceVectorType();
1217 if (vectorType.getRank() == 0) {
1219 return emitOpError(
"expected position to be empty with 0-D vector");
1222 if (vectorType.getRank() != 1)
1223 return emitOpError(
"unexpected >1 vector rank");
1225 return emitOpError(
"expected position for 1-D vector");
1229 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1231 if (!adaptor.getPosition())
1235 if (
auto splat = getVector().getDefiningOp<vector::SplatOp>())
1236 return splat.getInput();
1239 if (
auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1243 auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1244 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1248 auto srcElements = src.getValues<
Attribute>();
1250 uint64_t posIdx = pos.getInt();
1251 if (posIdx >= srcElements.size())
1254 return srcElements[posIdx];
1262 Value source, int64_t position) {
1282 build(builder, result, source, dynamicPos,
1287 ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1288 ExtractOp::Adaptor adaptor,
1290 auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1291 if (
static_cast<int64_t
>(adaptor.getStaticPosition().size()) ==
1292 vectorType.getRank()) {
1293 inferredReturnTypes.push_back(vectorType.getElementType());
1295 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1296 vectorType.getRank());
1298 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1299 vectorType.getScalableDims().drop_front(n)));
1307 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1308 return vectorType && vectorType.getShape().equals({1}) &&
1309 vectorType.getElementType() == r.front();
1311 if (l.size() == 1 && r.size() == 1 &&
1312 (isCompatible(l, r) || isCompatible(r, l)))
1319 auto dynamicMarkersCount =
1320 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1321 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1323 "mismatch between dynamic and static positions (kDynamic marker but no "
1324 "corresponding dynamic position) -- this can only happen due to an "
1325 "incorrect fold/rewrite");
1326 auto position = getMixedPosition();
1327 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1329 "expected position attribute of rank no greater than vector rank");
1332 int64_t constIdx = cast<IntegerAttr>(pos.get<
Attribute>()).getInt();
1333 if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1334 return emitOpError(
"expected position attribute #")
1336 <<
" to be a non-negative integer smaller than the "
1337 "corresponding vector dimension";
1344 template <
typename IntType>
1346 return llvm::to_vector<4>(llvm::map_range(
1347 arrayAttr.getAsRange<IntegerAttr>(),
1348 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1354 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1358 if (extractOp.hasDynamicPosition())
1362 ExtractOp currentOp = extractOp;
1364 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1365 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1368 if (currentOp.hasDynamicPosition())
1371 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1373 extractOp.setOperand(0, currentOp.getVector());
1376 std::reverse(globalPosition.begin(), globalPosition.end());
1377 extractOp.setStaticPosition(globalPosition);
1389 class ExtractFromInsertTransposeChainState {
1391 ExtractFromInsertTransposeChainState(ExtractOp e);
1400 template <
typename ContainerA,
typename ContainerB>
1401 bool isContainedWithin(
const ContainerA &a,
const ContainerB &b) {
1402 return a.size() <= b.size() &&
1403 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1410 template <
typename ContainerA,
typename ContainerB>
1411 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &b) {
1412 for (
auto [elemA, elemB] : llvm::zip(a, b)) {
1413 if (elemA < 0 || elemB < 0)
1428 void updateStateForNextIteration(
Value v) {
1435 LogicalResult handleTransposeOp();
1438 LogicalResult handleInsertOpWithMatchingPos(
Value &res);
1453 LogicalResult handleInsertOpWithPrefixPos(
Value &res);
1458 Value tryToFoldExtractOpInPlace(
Value source);
1460 ExtractOp extractOp;
1462 int64_t extractedRank;
1464 InsertOp nextInsertOp;
1465 TransposeOp nextTransposeOp;
1480 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1482 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1483 extractedRank(extractOp.getNumIndices()) {
1484 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1485 sentinels.reserve(vectorRank - extractedRank);
1486 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1487 sentinels.push_back(-(i + 1));
1489 extractOp.getStaticPosition().end());
1495 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1497 if (extractOp.hasDynamicPosition())
1500 if (!nextTransposeOp)
1503 nextTransposeOp.getPermutation(), extractOp.getContext()));
1510 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1513 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1520 res = nextInsertOp.getSource();
1522 return success(canFold());
1529 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(
Value &res) {
1531 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1544 res = nextInsertOp.getSource();
1552 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1555 if (extractOp.hasDynamicPosition())
1559 bool nothingToFold = (source == extractOp.getVector());
1560 if (nothingToFold || !canFold())
1565 extractOp.setStaticPosition(
1567 extractOp.getVectorMutable().assign(source);
1568 return extractOp.getResult();
1572 Value ExtractFromInsertTransposeChainState::fold() {
1574 if (extractOp.hasDynamicPosition())
1577 Value valueToExtractFrom = extractOp.getVector();
1578 updateStateForNextIteration(valueToExtractFrom);
1579 while (nextInsertOp || nextTransposeOp) {
1582 if (succeeded(handleTransposeOp())) {
1583 valueToExtractFrom = nextTransposeOp.getVector();
1584 updateStateForNextIteration(valueToExtractFrom);
1590 if (succeeded(handleInsertOpWithMatchingPos(result)))
1595 if (succeeded(handleInsertOpWithPrefixPos(result)))
1596 return tryToFoldExtractOpInPlace(result);
1606 valueToExtractFrom = nextInsertOp.getDest();
1607 updateStateForNextIteration(valueToExtractFrom);
1610 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1615 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1616 auto vecType = dyn_cast<VectorType>(type);
1617 return vecType && vecType.getRank() == 0;
1627 if (extractOp.hasDynamicPosition())
1630 Operation *defOp = extractOp.getVector().getDefiningOp();
1631 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1635 if (extractOp.getType() == source.
getType())
1637 auto getRank = [](
Type type) {
1638 return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1643 unsigned broadcastSrcRank = getRank(source.
getType());
1644 if (broadcastSrcRank == 0 && source.
getType() == extractOp.getType())
1647 unsigned extractResultRank = getRank(extractOp.getType());
1648 if (extractResultRank >= broadcastSrcRank)
1651 auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1652 auto broadcastVecType = llvm::dyn_cast<VectorType>(source.
getType());
1653 if (extractVecType && broadcastVecType &&
1654 extractVecType.getShape() !=
1655 broadcastVecType.getShape().take_back(extractResultRank))
1658 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1659 int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1665 broadcastOp.computeBroadcastedUnitDims();
1667 int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1668 for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1669 if (broadcastedUnitDims.contains(i))
1673 int64_t rankDiff = broadcastSrcRank - extractResultRank;
1674 extractPos.erase(extractPos.begin(),
1675 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1678 extractOp.setOperand(0, source);
1679 extractOp.setStaticPosition(extractPos);
1680 return extractOp.getResult();
1686 if (extractOp.hasDynamicPosition())
1689 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1699 auto getDimReverse = [](VectorType type, int64_t n) {
1700 return type.getShape().take_back(n + 1).front();
1702 int64_t destinationRank =
1703 llvm::isa<VectorType>(extractOp.getType())
1704 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1706 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1708 if (destinationRank > 0) {
1709 auto destinationType =
1710 llvm::cast<VectorType>(extractOp.getResult().getType());
1711 for (int64_t i = 0; i < destinationRank; i++) {
1715 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1716 getDimReverse(destinationType, i))
1723 std::reverse(extractedPos.begin(), extractedPos.end());
1726 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1727 strides.push_back(stride);
1729 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1732 int64_t position =
linearize(extractedPos, strides);
1736 int64_t numDimension =
1737 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1739 for (int64_t i = 0; i < numDimension; i++) {
1740 newStrides.push_back(stride);
1742 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1744 std::reverse(newStrides.begin(), newStrides.end());
1748 extractOp.setStaticPosition(newPosition);
1749 extractOp.setOperand(0, shapeCastOp.getSource());
1750 return extractOp.getResult();
1756 if (extractOp.hasDynamicPosition())
1759 auto extractStridedSliceOp =
1760 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1761 if (!extractStridedSliceOp)
1770 if (extractStridedSliceOp.hasNonUnitStrides())
1775 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1776 while (!sliceOffsets.empty()) {
1777 size_t lastOffset = sliceOffsets.size() - 1;
1778 if (sliceOffsets.back() != 0 ||
1779 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1780 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1782 sliceOffsets.pop_back();
1784 unsigned destinationRank = 0;
1785 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1786 destinationRank = vecType.getRank();
1789 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1790 sliceOffsets.size())
1794 assert(extractedPos.size() >= sliceOffsets.size());
1795 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1796 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1797 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1801 extractOp.setStaticPosition(extractedPos);
1802 return extractOp.getResult();
1808 if (extractOp.hasDynamicPosition())
1811 int64_t destinationRank =
1812 llvm::isa<VectorType>(extractOp.getType())
1813 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1815 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1825 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1826 insertOp.getSourceVectorType().getRank();
1827 if (destinationRank > insertOp.getSourceVectorType().getRank())
1829 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1832 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1833 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1836 bool disjoint =
false;
1838 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1839 int64_t start = insertOffsets[dim];
1841 (dim < insertRankDiff)
1843 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1844 int64_t end = start + size;
1845 int64_t offset = extractOffsets[dim];
1847 if (start <= offset && offset < end) {
1848 if (dim >= insertRankDiff)
1849 offsetDiffs.push_back(offset - start);
1859 int64_t srcRankDiff =
1860 insertOp.getSourceVectorType().getRank() - destinationRank;
1861 for (int64_t i = 0; i < destinationRank; i++) {
1862 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1863 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1867 extractOp.getVectorMutable().assign(insertOp.getSource());
1870 extractOp.setStaticPosition(offsetDiffs);
1871 return extractOp.getResult();
1875 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1888 if (extractOp.hasDynamicPosition())
1892 auto fromElementsOp = extractOp.getVector().
getDefiningOp<FromElementsOp>();
1893 if (!fromElementsOp)
1897 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1898 if (vecType.isScalable())
1902 int64_t rank = vecType.getRank();
1904 if (extractOp.getType() != vecType.getElementType())
1906 assert(
static_cast<int64_t
>(indices.size()) == rank &&
1907 "unexpected number of indices");
1912 for (
int i = rank - 1; i >= 0; --i) {
1913 flatIndex += indices[i] * stride;
1914 stride *= vecType.getDimSize(i);
1916 return fromElementsOp.getElements()[flatIndex];
1923 if (getNumIndices() == 0 && getVector().
getType() == getResult().
getType())
1927 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
1951 Operation *defOp = extractOp.getVector().getDefiningOp();
1952 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1956 if (extractOp.getType() == source.
getType())
1958 auto getRank = [](
Type type) {
1959 return llvm::isa<VectorType>(type)
1960 ? llvm::cast<VectorType>(type).getRank()
1963 unsigned broadcastSrcRank = getRank(source.
getType());
1964 unsigned extractResultRank = getRank(extractOp.getType());
1968 if (extractResultRank < broadcastSrcRank)
1972 if (extractResultRank == 0) {
1973 assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.
getType()));
1978 extractOp, extractOp.getType(), source);
1984 class ExtractOpSplatConstantFolder final :
public OpRewritePattern<ExtractOp> {
1992 Value sourceVector = extractOp.getVector();
1996 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
1999 TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2000 if (
auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2008 class ExtractOpNonSplatConstantFolder final
2016 if (extractOp.hasDynamicPosition())
2021 Value sourceVector = extractOp.getVector();
2026 auto vecTy = llvm::cast<VectorType>(sourceVector.
getType());
2027 if (vecTy.isScalable())
2031 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2032 if (!dense || dense.isSplat())
2038 copy(extractOp.getStaticPosition(), completePositions.begin());
2039 int64_t elemBeginPosition =
2041 auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2044 if (
auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2046 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2049 newAttr = *denseValuesBegin;
2065 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2069 VectorType extractedMaskType =
2070 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2072 if (!extractedMaskType)
2075 auto maskOperands = createMaskOp.getOperands();
2077 VectorType maskType = createMaskOp.getVectorType();
2079 bool containsUnknownDims =
false;
2082 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2084 int64_t pos = extractOpPos[dimIdx];
2085 Value operand = maskOperands[dimIdx];
2086 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2089 containsUnknownDims =
true;
2093 int64_t createMaskBound =
2094 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2096 if (pos != ShapedType::kDynamic) {
2099 allFalse |= pos >= createMaskBound;
2100 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2104 containsUnknownDims =
true;
2111 }
else if (!containsUnknownDims) {
2113 extractOp, extractedMaskType,
2114 maskOperands.drop_front(extractOpPos.size()));
2124 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2126 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2130 VectorType sourceType = castOp.getSourceVectorType();
2131 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2135 if (sourceType.getNumElements() != targetType.getNumElements())
2139 castOp.getSource());
2149 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2152 if (extractOp.hasDynamicPosition())
2156 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2161 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2162 if (!fromElementsOp)
2164 VectorType inputType = fromElementsOp.getType();
2167 if (resultType.isScalable() || inputType.isScalable())
2173 llvm::to_vector(extractOp.getStaticPosition());
2174 firstElementPos.append(resultType.getRank(), 0);
2177 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2178 flatIndex += firstElementPos[i] * stride;
2179 stride *= inputType.getDimSize(i);
2184 extractOp, resultType,
2185 fromElementsOp.getElements().slice(flatIndex,
2186 resultType.getNumElements()));
2193 results.
add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2194 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2195 results.
add(foldExtractFromShapeCastToShapeCast);
2196 results.
add(foldExtractFromFromElements);
2201 for (
auto attr : arrayAttr)
2202 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2209 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2224 if (!llvm::all_equal(fromElementsOp.getElements()))
2227 fromElementsOp.getElements().front());
2245 int64_t rankDiff = dstShape.size() - srcShape.size();
2246 int64_t dstDim = rankDiff;
2248 for (
auto [s1, s2] :
2249 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2251 assert(s1 == 1 &&
"expected dim-1 broadcasting");
2261 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2280 Value BroadcastOp::createOrFoldBroadcastOp(
2283 assert(!dstShape.empty() &&
"unexpected empty dst shape");
2287 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
2288 if (broadcastedDims.contains(i))
2290 checkShape.push_back(dstShape[i]);
2292 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2293 "ill-formed broadcastedDims contains values not confined to "
2298 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
2302 if (!srcVectorType) {
2303 assert(checkShape.empty() &&
2304 "ill-formed createOrFoldBroadcastOp arguments");
2305 return b.
createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2308 assert(srcVectorType.getShape().equals(checkShape) &&
2309 "ill-formed createOrFoldBroadcastOp arguments");
2320 broadcastShape.reserve(dstShape.size());
2336 int64_t nextSrcShapeDim = broadcastedDims.size();
2337 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2338 if (broadcastedDims.contains(i)) {
2343 broadcastShape.push_back(dstShape[i]);
2344 permutation[i] = broadcastShape.size() - 1;
2350 permutation[i] = nextSrcShapeDim++;
2354 llvm::append_range(broadcastShape, srcVectorType.getShape());
2359 "unexpected dim-1 broadcast");
2361 VectorType broadcastType =
VectorType::get(broadcastShape, elementType);
2363 vector::BroadcastableToResult::Success &&
2364 "must be broadcastable");
2368 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2369 if (permutation[i] != i)
2370 return b.
createOrFold<vector::TransposeOp>(loc, res, permutation);
2377 std::pair<int, int> *mismatchingDims) {
2381 return BroadcastableToResult::Success;
2383 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2385 return BroadcastableToResult::SourceTypeNotAVector;
2387 int64_t srcRank = srcVectorType.getRank();
2388 int64_t dstRank = dstVectorType.getRank();
2389 if (srcRank > dstRank)
2390 return BroadcastableToResult::SourceRankHigher;
2393 int64_t lead = dstRank - srcRank;
2394 for (int64_t r = 0; r < srcRank; ++r) {
2395 int64_t srcDim = srcVectorType.getDimSize(r);
2396 int64_t dstDim = dstVectorType.getDimSize(lead + r);
2397 if (srcDim != 1 && srcDim != dstDim) {
2398 if (mismatchingDims) {
2399 mismatchingDims->first = srcDim;
2400 mismatchingDims->second = dstDim;
2402 return BroadcastableToResult::DimensionMismatch;
2406 return BroadcastableToResult::Success;
2410 std::pair<int, int> mismatchingDims;
2412 getSourceType(), getResultVectorType(), &mismatchingDims);
2413 if (res == BroadcastableToResult::Success)
2415 if (res == BroadcastableToResult::SourceRankHigher)
2416 return emitOpError(
"source rank higher than destination rank");
2417 if (res == BroadcastableToResult::DimensionMismatch)
2418 return emitOpError(
"dimension mismatch (")
2419 << mismatchingDims.first <<
" vs. " << mismatchingDims.second <<
")";
2420 if (res == BroadcastableToResult::SourceTypeNotAVector)
2421 return emitOpError(
"source type is not a vector");
2422 llvm_unreachable(
"unexpected vector.broadcast op error");
2426 if (getSourceType() == getResultVectorType())
2428 if (!adaptor.getSource())
2430 auto vectorType = getResultVectorType();
2431 if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2433 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2446 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2450 broadcastOp.getResultVectorType(),
2451 srcBroadcast.getSource());
2461 results.
add<BroadcastFolder>(context);
2474 VectorType resultType = getResultVectorType();
2475 VectorType v1Type = getV1VectorType();
2476 VectorType v2Type = getV2VectorType();
2478 int64_t resRank = resultType.getRank();
2479 int64_t v1Rank = v1Type.getRank();
2480 int64_t v2Rank = v2Type.getRank();
2481 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2482 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2483 if (!wellFormed0DCase && !wellFormedNDCase)
2484 return emitOpError(
"rank mismatch");
2487 for (int64_t r = 1; r < v1Rank; ++r) {
2488 int64_t resDim = resultType.getDimSize(r);
2489 int64_t v1Dim = v1Type.getDimSize(r);
2490 int64_t v2Dim = v2Type.getDimSize(r);
2491 if (resDim != v1Dim || v1Dim != v2Dim)
2492 return emitOpError(
"dimension mismatch");
2495 auto maskAttr = getMask().getValue();
2496 int64_t maskLength = maskAttr.size();
2497 if (maskLength <= 0)
2498 return emitOpError(
"invalid mask length");
2499 if (maskLength != resultType.getDimSize(0))
2500 return emitOpError(
"mask length mismatch");
2502 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2503 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2505 auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
2506 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
2507 return emitOpError(
"mask index #") << (en.index() + 1) <<
" out of range";
2513 ShuffleOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
2514 ShuffleOp::Adaptor adaptor,
2516 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2517 auto v1Rank = v1Type.getRank();
2521 shape.reserve(v1Rank);
2522 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2525 llvm::append_range(shape, v1Type.getShape().drop_front());
2526 inferredReturnTypes.push_back(
2532 uint64_t expected = begin;
2533 return idxArr.size() == width &&
2534 llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2535 [&expected](
auto attr) {
2536 return attr.getZExtValue() == expected++;
2540 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2541 VectorType v1Type = getV1VectorType();
2544 if (v1Type.getRank() == 0)
2548 if (!v1Type.isScalable() &&
2552 if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2554 getV2VectorType().getDimSize(0)))
2557 Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2562 llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).
getType());
2565 if (lhsType.getRank() != 1)
2567 int64_t lhsSize = lhsType.getDimSize(0);
2570 auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<
Attribute>();
2571 auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<
Attribute>();
2572 for (
const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2573 int64_t i = index.getZExtValue();
2575 results.push_back(rhsElements[i - lhsSize]);
2577 results.push_back(lhsElements[i]);
2593 VectorType v1VectorType = shuffleOp.getV1VectorType();
2594 ArrayAttr mask = shuffleOp.getMask();
2595 if (v1VectorType.getRank() > 0)
2597 if (mask.size() != 1)
2600 if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
2617 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2618 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2620 if (!v1Splat || !v2Splat)
2623 if (v1Splat.getInput() != v2Splat.getInput())
2639 VectorType resultType = op.getResultVectorType();
2640 if (resultType.isScalable())
2642 op,
"ShuffleOp can't represent a scalable interleave");
2644 if (resultType.getRank() != 1)
2646 op,
"ShuffleOp can't represent an n-D interleave");
2648 VectorType sourceType = op.getV1VectorType();
2649 if (sourceType != op.getV2VectorType() ||
2650 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2652 op,
"ShuffleOp types don't match an interleave");
2655 ArrayAttr shuffleMask = op.getMask();
2656 int64_t resultVectorSize = resultType.getNumElements();
2657 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2658 int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
2659 int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
2660 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2662 "ShuffleOp mask not interleaving");
2674 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2684 build(builder, result, source, dest, {});
2688 auto dstVectorType = getDestVectorType();
2689 if (dstVectorType.getRank() == 0) {
2691 return emitOpError(
"expected position to be empty with 0-D vector");
2694 if (dstVectorType.getRank() != 1)
2695 return emitOpError(
"unexpected >1 vector rank");
2697 return emitOpError(
"expected position for 1-D vector");
2701 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2703 if (!adaptor.getPosition())
2706 auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2707 auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2708 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2709 if (!src || !dst || !pos)
2715 auto dstElements = dst.getValues<
Attribute>();
2719 uint64_t posIdx = pos.getInt();
2720 if (posIdx >= results.size())
2722 results[posIdx] = src;
2732 Value source,
Value dest, int64_t position) {
2745 posVals.reserve(position.size());
2746 llvm::transform(position, std::back_inserter(posVals),
2748 build(builder, result, source, dest, posVals);
2757 build(builder, result, source, dest, dynamicPos,
2763 auto destVectorType = getDestVectorType();
2764 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
2766 "expected position attribute of rank no greater than dest vector rank");
2767 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2768 if (srcVectorType &&
2769 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2770 static_cast<unsigned>(destVectorType.getRank())))
2771 return emitOpError(
"expected position attribute rank + source rank to "
2772 "match dest vector rank");
2773 if (!srcVectorType &&
2774 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
2776 "expected position attribute rank to match the dest vector rank");
2778 if (
auto attr = pos.dyn_cast<
Attribute>()) {
2779 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2780 if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2781 return emitOpError(
"expected position attribute #")
2783 <<
" to be a non-negative integer smaller than the "
2785 "dest vector dimension";
2802 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2803 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2804 srcVecType.getNumElements())
2807 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2819 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2820 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2822 if (!srcSplat || !dstSplat)
2825 if (srcSplat.getInput() != dstSplat.getInput())
2840 static constexpr int64_t vectorSizeFoldThreshold = 256;
2845 if (op.hasDynamicPosition())
2854 auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
2858 VectorType destTy = destVector.getType();
2859 if (destTy.isScalable())
2863 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2864 !destVector.hasOneUse())
2867 Value sourceValue = op.getSource();
2875 copy(op.getStaticPosition(), completePositions.begin());
2876 int64_t insertBeginPosition =
2880 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
2881 llvm::append_range(insertedValues, denseSource.getValues<
Attribute>());
2883 insertedValues.push_back(sourceCst);
2885 auto allValues = llvm::to_vector(denseDest.getValues<
Attribute>());
2886 copy(insertedValues, allValues.begin() + insertBeginPosition);
2898 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2899 InsertOpConstantFolder>(context);
2905 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2906 if (getNumIndices() == 0)
2930 template <
typename OpType>
2932 ArrayAttr arrayAttr,
2934 StringRef attrName) {
2935 if (arrayAttr.size() > shape.size())
2937 << attrName <<
" attribute of rank no greater than vector rank";
2944 template <
typename OpType>
2945 static LogicalResult
2947 int64_t
max, StringRef attrName,
2948 bool halfOpen =
true) {
2949 for (
auto attr : arrayAttr) {
2950 auto val = llvm::cast<IntegerAttr>(attr).getInt();
2954 if (val < min || val >= upper)
2955 return op.
emitOpError(
"expected ") << attrName <<
" to be confined to ["
2956 <<
min <<
", " << upper <<
")";
2964 template <
typename OpType>
2965 static LogicalResult
2968 bool halfOpen =
true, int64_t
min = 0) {
2969 for (
auto [index, attrDimPair] :
2971 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
2972 int64_t
max = std::get<1>(attrDimPair);
2975 if (val < min || val >=
max)
2977 << attrName <<
" dimension " << index <<
" to be confined to ["
2978 <<
min <<
", " <<
max <<
")";
2988 template <
typename OpType>
2990 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2992 bool halfOpen =
true, int64_t
min = 1) {
2993 assert(arrayAttr1.size() <= shape.size());
2994 assert(arrayAttr2.size() <= shape.size());
2995 for (
auto [index, it] :
2997 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
2998 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
2999 int64_t
max = std::get<2>(it);
3002 if (val1 + val2 < 0 || val1 + val2 >=
max)
3004 << attrName1 <<
", " << attrName2 <<
") dimension " << index
3005 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3012 auto attrs = llvm::map_range(values, [context](int64_t v) ->
Attribute {
3019 auto sourceVectorType = getSourceVectorType();
3020 auto destVectorType = getDestVectorType();
3021 auto offsets = getOffsetsAttr();
3022 auto strides = getStridesAttr();
3023 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3025 "expected offsets of same size as destination vector rank");
3026 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3027 return emitOpError(
"expected strides of same size as source vector rank");
3028 if (sourceVectorType.getRank() > destVectorType.getRank())
3030 "expected source rank to be no greater than destination rank");
3032 auto sourceShape = sourceVectorType.getShape();
3033 auto destShape = destVectorType.getShape();
3035 destShape.size() - sourceShape.size(), 0);
3036 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3037 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3038 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3047 offName,
"source vector shape",
3051 unsigned rankDiff = destShape.size() - sourceShape.size();
3052 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3053 if (sourceVectorType.getScalableDims()[idx] !=
3054 destVectorType.getScalableDims()[idx + rankDiff]) {
3055 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3058 if (sourceVectorType.getScalableDims()[idx]) {
3059 auto sourceSize = sourceShape[idx];
3060 auto destSize = destShape[idx + rankDiff];
3061 if (sourceSize != destSize) {
3062 return emitOpError(
"expected size at idx=")
3064 << (
" to match the corresponding base size from the input "
3066 << sourceSize << (
" vs ") << destSize << (
")");
3077 class FoldInsertStridedSliceSplat final
3082 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3085 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3087 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3089 if (!srcSplatOp || !destSplatOp)
3092 if (srcSplatOp.getInput() != destSplatOp.getInput())
3095 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3102 class FoldInsertStridedSliceOfExtract final
3107 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3109 auto extractStridedSliceOp =
3110 insertStridedSliceOp.getSource()
3111 .getDefiningOp<vector::ExtractStridedSliceOp>();
3113 if (!extractStridedSliceOp)
3116 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3120 if (extractStridedSliceOp.getStrides() !=
3121 insertStridedSliceOp.getStrides() ||
3122 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3125 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3132 class InsertStridedSliceConstantFolder final
3139 static constexpr int64_t vectorSizeFoldThreshold = 256;
3150 VectorType destTy = destVector.getType();
3151 if (destTy.isScalable())
3155 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3156 !destVector.hasOneUse())
3159 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3167 if (op.hasNonUnitStrides())
3170 VectorType sliceVecTy = sourceValue.getType();
3172 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3182 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3183 auto sliceValuesIt = denseSlice.value_begin<
Attribute>();
3184 auto newValues = llvm::to_vector(denseDest.getValues<
Attribute>());
3187 currDestPosition.begin() + rankDifference, currDestPosition.end());
3191 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
3192 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
3193 assert(sliceValuesIt != denseSlice.value_end<
Attribute>() &&
3194 "Invalid slice element");
3195 newValues[linearizedPosition] = *sliceValuesIt;
3208 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3210 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3211 InsertStridedSliceConstantFolder>(context);
3214 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3215 if (getSourceVectorType() == getDestVectorType())
3232 p <<
" " << getLhs() <<
", " << getRhs();
3234 p <<
", " << getAcc();
3237 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
3248 if (operandsInfo.size() < 2)
3250 "expected at least 2 operands");
3251 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3252 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3255 "expected vector type for operand #1");
3260 vRHS.getScalableDims()[0]};
3262 vLHS.getElementType(), scalableDimsRes);
3266 resType =
VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3272 OuterProductOp::getKindAttrName(result.
name),
3274 OuterProductOp::getDefaultKind()));
3280 (operandsInfo.size() > 2 &&
3286 Type tRHS = getOperandTypeRHS();
3287 VectorType vLHS = getOperandVectorTypeLHS(),
3288 vRHS = llvm::dyn_cast<VectorType>(tRHS),
3289 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3291 if (vLHS.getRank() != 1)
3292 return emitOpError(
"expected 1-d vector for operand #1");
3296 if (vRHS.getRank() != 1)
3297 return emitOpError(
"expected 1-d vector for operand #2");
3298 if (vRES.getRank() != 2)
3299 return emitOpError(
"expected 2-d vector result");
3300 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3301 return emitOpError(
"expected #1 operand dim to match result dim #1");
3302 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3303 return emitOpError(
"expected #2 operand dim to match result dim #2");
3304 if (vLHS.isScalable() && !vRHS.isScalable()) {
3308 "expected either both or only #2 operand dim to be scalable");
3312 if (vRES.getRank() != 1)
3313 return emitOpError(
"expected 1-d vector result");
3314 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3315 return emitOpError(
"expected #1 operand dim to match result dim #1");
3318 if (vACC && vACC != vRES)
3319 return emitOpError(
"expected operand #3 of same type as result type");
3323 return emitOpError(
"unsupported outerproduct type");
3332 Type OuterProductOp::getExpectedMaskType() {
3333 auto vecType = this->getResultVectorType();
3336 vecType.getScalableDims());
3345 auto inputVectorType = getInputVectorType();
3346 auto outputVectorType = getOutputVectorType();
3347 int64_t inputShapeRank = getNumInputShapeSizes();
3348 int64_t outputShapeRank = getNumOutputShapeSizes();
3350 getFixedVectorSizes(fixedVectorSizes);
3351 int64_t numFixedVectorSizes = fixedVectorSizes.size();
3353 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
3354 return emitError(
"invalid input shape for vector type ") << inputVectorType;
3356 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
3357 return emitError(
"invalid output shape for vector type ")
3358 << outputVectorType;
3362 unsigned inputVectorRank = inputVectorType.getRank();
3363 for (
unsigned i = 0; i < numFixedVectorSizes; ++i) {
3364 unsigned index = inputVectorRank - numFixedVectorSizes - i;
3365 if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
3366 return emitError(
"fixed vector size must match input vector for dim ")
3370 unsigned outputVectorRank = outputVectorType.getRank();
3371 for (
unsigned i = 0; i < numFixedVectorSizes; ++i) {
3372 unsigned index = outputVectorRank - numFixedVectorSizes - i;
3373 if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
3374 return emitError(
"fixed vector size must match output vector for dim ")
3380 auto isDefByConstant = [](
Value operand) {
3383 if (llvm::all_of(getInputShape(), isDefByConstant) &&
3384 llvm::all_of(getOutputShape(), isDefByConstant)) {
3385 int64_t numInputElements = 1;
3386 for (
auto operand : getInputShape())
3388 int64_t numOutputElements = 1;
3389 for (
auto operand : getOutputShape())
3391 if (numInputElements != numOutputElements)
3392 return emitError(
"product of input and output shape sizes must match");
3410 ArrayAttr offsets, ArrayAttr sizes,
3411 ArrayAttr strides) {
3412 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3414 shape.reserve(vectorType.getRank());
3416 for (
unsigned e = offsets.size(); idx < e; ++idx)
3417 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3418 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3419 shape.push_back(vectorType.getShape()[idx]);
3422 vectorType.getScalableDims());
3435 offsetsAttr, sizesAttr, stridesAttr));
3436 result.
addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.
name),
3440 result.
addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.
name),
3445 auto type = getSourceVectorType();
3446 auto offsets = getOffsetsAttr();
3447 auto sizes = getSizesAttr();
3448 auto strides = getStridesAttr();
3449 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3451 "expected offsets, sizes and strides attributes of same size");
3453 auto shape = type.getShape();
3454 auto offName = getOffsetsAttrName();
3455 auto sizesName = getSizesAttrName();
3456 auto stridesName = getStridesAttrName();
3472 shape, offName, sizesName,
3477 offsets, sizes, strides);
3478 if (getResult().
getType() != resultType)
3479 return emitOpError(
"expected result type to be ") << resultType;
3481 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
3482 if (type.getScalableDims()[idx]) {
3483 auto inputDim = type.getShape()[idx];
3484 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3485 if (inputDim != inputSize)
3486 return emitOpError(
"expected size at idx=")
3488 << (
" to match the corresponding base size from the input "
3490 << inputSize << (
" vs ") << inputDim << (
")");
3500 static LogicalResult
3503 auto getElement = [](ArrayAttr array,
int idx) {
3504 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3506 ArrayAttr extractOffsets = op.getOffsets();
3508 ArrayAttr extractSizes = op.getSizes();
3509 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3511 if (op.getSourceVectorType().getRank() !=
3512 insertOp.getSourceVectorType().getRank())
3514 ArrayAttr insertOffsets = insertOp.getOffsets();
3515 ArrayAttr insertStrides = insertOp.getStrides();
3518 if (extractOffsets.size() > insertOffsets.size())
3520 bool patialoverlap =
false;
3521 bool disjoint =
false;
3523 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3524 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
3526 int64_t start = getElement(insertOffsets, dim);
3527 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3528 int64_t offset = getElement(extractOffsets, dim);
3529 int64_t size = getElement(extractSizes, dim);
3531 if (start <= offset && offset < end) {
3534 if (offset + size > end)
3535 patialoverlap =
true;
3536 offsetDiffs.push_back(offset - start);
3543 if (!disjoint && !patialoverlap) {
3553 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3563 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3564 if (getSourceVectorType() == getResult().
getType())
3579 class StridedSliceConstantMaskFolder final
3584 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3588 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3589 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3590 if (!constantMaskOp)
3593 if (extractStridedSliceOp.hasNonUnitStrides())
3607 sliceMaskDimSizes.reserve(maskDimSizes.size());
3608 for (
auto [maskDimSize, sliceOffset, sliceSize] :
3609 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3610 int64_t sliceMaskDimSize =
std::max(
3611 static_cast<int64_t
>(0),
3612 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3613 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3616 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3617 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3618 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3621 if (llvm::is_contained(sliceMaskDimSizes, 0))
3622 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3627 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3634 class StridedSliceSplatConstantFolder final
3639 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3643 Value sourceVector = extractStridedSliceOp.getVector();
3648 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3662 class StridedSliceNonSplatConstantFolder final
3667 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3671 Value sourceVector = extractStridedSliceOp.getVector();
3677 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3678 if (!dense || dense.isSplat())
3682 if (extractStridedSliceOp.hasNonUnitStrides())
3685 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.
getType());
3689 VectorType sliceVecTy = extractStridedSliceOp.getType();
3691 int64_t sliceRank = sliceVecTy.getRank();
3703 auto denseValuesBegin = dense.value_begin<
Attribute>();
3705 sliceValues.reserve(sliceVecTy.getNumElements());
3708 int64_t linearizedPosition =
linearize(currSlicePosition, sourceStrides);
3709 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3711 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3715 assert(
static_cast<int64_t
>(sliceValues.size()) ==
3716 sliceVecTy.getNumElements() &&
3717 "Invalid number of slice elements");
3727 class StridedSliceBroadcast final
3734 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3739 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3740 auto dstVecType = llvm::cast<VectorType>(op.getType());
3741 unsigned dstRank = dstVecType.getRank();
3742 unsigned rankDiff = dstRank - srcRank;
3746 bool lowerDimMatch =
true;
3747 for (
unsigned i = 0; i < srcRank; i++) {
3748 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3749 lowerDimMatch =
false;
3758 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3759 if (!lowerDimMatch && !isScalarSrc) {
3760 source = rewriter.
create<ExtractStridedSliceOp>(
3772 class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
3778 auto splat = op.getVector().getDefiningOp<SplatOp>();
3788 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3792 results.
add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3793 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3794 StridedSliceSplat>(context);
3803 VectorType vectorType,
Value source,
3804 ValueRange indices, AffineMapAttr permutationMapAttr,
3805 ArrayAttr inBoundsAttr) {
3806 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3807 Value padding = builder.
create<arith::ConstantOp>(
3809 build(builder, result, vectorType, source, indices, permutationMapAttr,
3810 padding,
Value(), inBoundsAttr);
3815 VectorType vectorType,
Value source,
3819 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3823 build(builder, result, vectorType, source, indices, permutationMapAttr,
3829 VectorType vectorType,
Value source,
3833 llvm::cast<ShapedType>(source.
getType()), vectorType);
3835 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3839 build(builder, result, vectorType, source, indices, permutationMapAttr,
3841 Value(), inBoundsAttr);
3847 VectorType vectorType,
Value source,
3850 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3851 Value padding = builder.
create<arith::ConstantOp>(
3853 build(builder, result, vectorType, source, indices, padding, inBounds);
3856 template <
typename EmitFun>
3858 EmitFun emitOpError) {
3860 for (
auto expr : permutationMap.
getResults()) {
3861 auto dim = dyn_cast<AffineDimExpr>(expr);
3862 auto zero = dyn_cast<AffineConstantExpr>(expr);
3864 if (zero.getValue() != 0) {
3866 "requires a projected permutation_map (at most one dim or the zero "
3867 "constant can appear in each result)");
3872 return emitOpError(
"requires a projected permutation_map (at most one "
3873 "dim or the zero constant can appear in each result)");
3875 if (seen[dim.getPosition()]) {
3877 "requires a permutation_map that is a permutation (found one dim "
3878 "used more than once)");
3880 seen[dim.getPosition()] =
true;
3885 static LogicalResult
3887 VectorType vectorType, VectorType maskType,
3888 VectorType inferredMaskType,
AffineMap permutationMap,
3889 ArrayAttr inBounds) {
3891 return op->
emitOpError(
"masked attribute has been removed. "
3892 "Use in_bounds instead.");
3895 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
3897 "requires source to be a memref or ranked tensor type");
3899 auto elementType = shapedType.getElementType();
3900 DataLayout dataLayout = DataLayout::closest(op);
3901 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
3903 unsigned sourceVecSize =
3905 vectorElementType.getShape().back();
3906 unsigned resultVecSize =
3908 vectorType.getShape().back();
3909 if (resultVecSize % sourceVecSize != 0)
3911 "requires the bitwidth of the minor 1-D vector to be an integral "
3912 "multiple of the bitwidth of the minor 1-D vector of the source");
3914 unsigned sourceVecEltRank = vectorElementType.getRank();
3915 unsigned resultVecRank = vectorType.getRank();
3916 if (sourceVecEltRank > resultVecRank)
3918 "requires source vector element and vector result ranks to match.");
3919 unsigned rankOffset = resultVecRank - sourceVecEltRank;
3922 return op->
emitOpError(
"requires a permutation_map with result dims of "
3923 "the same rank as the vector type");
3926 return op->
emitOpError(
"does not support masks with vector element type");
3929 unsigned minorSize =
3930 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3931 unsigned resultVecSize =
3935 "requires the bitwidth of the minor 1-D vector to be an integral "
3936 "multiple of the bitwidth of the source element type");
3940 return op->
emitOpError(
"requires a permutation_map with result dims of "
3941 "the same rank as the vector type");
3945 return op->
emitOpError(
"requires permutation_map without symbols");
3947 if (permutationMap.
getNumInputs() != shapedType.getRank())
3948 return op->
emitOpError(
"requires a permutation_map with input dims of the "
3949 "same rank as the source type");
3951 if (maskType && maskType != inferredMaskType)
3953 << inferredMaskType <<
") and mask operand type (" << maskType
3956 if (permutationMap.
getNumResults() !=
static_cast<int64_t
>(inBounds.size()))
3957 return op->
emitOpError(
"expects the in_bounds attr of same rank "
3958 "as permutation_map results: ")
3960 <<
" vs inBounds of size: " << inBounds.size();
3961 for (
unsigned int i = 0, e = permutationMap.
getNumResults(); i < e; ++i)
3962 if (isa<AffineConstantExpr>(permutationMap.
getResult(i)) &&
3963 !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
3964 return op->
emitOpError(
"requires broadcast dimensions to be in-bounds");
3971 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3972 if (op.getPermutationMap().isMinorIdentity())
3973 elidedAttrs.push_back(op.getPermutationMapAttrName());
3975 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
3976 elidedAttrs.push_back(op.getInBoundsAttrName());
3981 p <<
" " << getSource() <<
"[" <<
getIndices() <<
"], " << getPadding();
3983 p <<
", " << getMask();
3992 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
4015 if (hasMask.succeeded()) {
4022 if (types.size() != 2)
4023 return parser.
emitError(typesLoc,
"requires two types");
4025 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4026 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4027 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4028 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4030 return parser.
emitError(typesLoc,
"requires vector type");
4031 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.
name);
4038 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4040 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.
name);
4042 if (!inBoundsAttr) {
4052 if (hasMask.succeeded()) {
4053 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4055 maskInfo.
location,
"does not support masks with vector element type");
4058 "expected the same rank for the vector and the "
4059 "results of the permutation map");
4067 result.
addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4069 {1, static_cast<int32_t>(indexInfo.size()), 1,
4070 static_cast<int32_t>(hasMask.succeeded())}));
4076 ShapedType shapedType = getShapedType();
4078 VectorType maskType = getMaskType();
4079 auto paddingType = getPadding().getType();
4080 auto permutationMap = getPermutationMap();
4081 VectorType inferredMaskType =
4084 auto sourceElementType = shapedType.getElementType();
4086 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
4087 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4089 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4090 shapedType, vectorType, maskType,
4091 inferredMaskType, permutationMap, getInBounds())))
4094 if (
auto sourceVectorElementType =
4095 llvm::dyn_cast<VectorType>(sourceElementType)) {
4098 if (sourceVectorElementType != paddingType)
4100 "requires source element type and padding type to match.");
4104 if (!VectorType::isValidElementType(paddingType))
4105 return emitOpError(
"requires valid padding vector elemental type");
4108 if (paddingType != sourceElementType)
4110 "requires formal padding and source of the same elemental type");
4114 [&](Twine t) {
return emitOpError(t); });
4121 Type TransferReadOp::getExpectedMaskType() {
4125 template <
typename TransferOp>
4126 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4129 if (op.getShapedType().isDynamicDim(indicesIdx))
4131 Value index = op.getIndices()[indicesIdx];
4133 if (!cstOp.has_value())
4136 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4137 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4139 return cstOp.value() + vectorSize <= sourceSize;
4142 template <
typename TransferOp>
4146 if (op.getTransferRank() == 0)
4148 AffineMap permutationMap = op.getPermutationMap();
4149 bool changed =
false;
4151 newInBounds.reserve(op.getTransferRank());
4152 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
4154 if (op.isDimInBounds(i)) {
4155 newInBounds.push_back(
true);
4160 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
4161 assert(dimExpr &&
"Broadcast dims must be in-bounds");
4164 newInBounds.push_back(inBounds);
4166 changed |= inBounds;
4176 template <
typename TransferOp>
4178 auto mask = op.getMask();
4185 op.getMaskMutable().clear();
4199 static Value foldRAW(TransferReadOp readOp) {
4200 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4202 auto defWrite = readOp.getSource().
getDefiningOp<vector::TransferWriteOp>();
4205 return defWrite.getVector();
4207 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4208 cast<VectorTransferOpInterface>(readOp.getOperation())))
4210 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4216 if (
Value vec = foldRAW(*
this))
4230 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4234 void TransferReadOp::getEffects(
4237 if (llvm::isa<MemRefType>(getShapedType()))
4265 struct TransferReadAfterWriteToBroadcast
4271 if (readOp.hasOutOfBoundsDim() ||
4272 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4274 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4279 if (readOp.getTransferChunkAccessed() !=
4280 defWrite.getTransferChunkAccessed())
4287 if (readOp.getIndices() != defWrite.getIndices() ||
4288 readOp.getMask() != defWrite.getMask())
4290 Value vec = defWrite.getVector();
4312 broadcastShape[pos.value()] = destShape[pos.index()];
4313 broadcastScalableFlags[pos.value()] =
4314 readOp.getVectorType().getScalableDims()[pos.index()];
4317 broadcastShape, defWrite.getVectorType().getElementType(),
4318 broadcastScalableFlags);
4319 vec = rewriter.
create<vector::BroadcastOp>(loc, broadcastedType, vec);
4330 results.
add<TransferReadAfterWriteToBroadcast>(context);
4340 AffineMapAttr permutationMapAttr,
4342 ArrayAttr inBoundsAttr) {
4343 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
4344 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4345 mask, inBoundsAttr);
4351 AffineMapAttr permutationMapAttr,
4352 ArrayAttr inBoundsAttr) {
4353 build(builder, result, vector, dest, indices, permutationMapAttr,
4354 Value(), inBoundsAttr);
4365 (inBounds && !inBounds.value().empty())
4368 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
4369 build(builder, result, vector, dest, indices, permutationMapAttr,
4370 Value(), inBoundsAttr);
4378 auto vectorType = llvm::cast<VectorType>(vector.
getType());
4380 llvm::cast<ShapedType>(dest.
getType()), vectorType);
4381 build(builder, result, vector, dest, indices, permutationMap, inBounds);
4397 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
4402 if (types.size() != 2)
4403 return parser.
emitError(typesLoc,
"requires two types");
4405 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4407 return parser.
emitError(typesLoc,
"requires vector type");
4408 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4409 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4410 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4411 auto permMapAttrName =
4412 TransferWriteOp::getPermutationMapAttrName(result.
name);
4419 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4421 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.
name);
4423 if (!inBoundsAttr) {
4432 if (hasMask.succeeded()) {
4433 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4435 maskInfo.
location,
"does not support masks with vector element type");
4438 "expected the same rank for the vector and the "
4439 "results of the permutation map");
4445 result.
addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4447 {1, 1, static_cast<int32_t>(indexInfo.size()),
4448 static_cast<int32_t>(hasMask.succeeded())}));
4449 return failure(llvm::isa<RankedTensorType>(shapedType) &&
4454 p <<
" " << getVector() <<
", " << getSource() <<
"[" <<
getIndices() <<
"]";
4456 p <<
", " << getMask();
4463 ShapedType shapedType = getShapedType();
4465 VectorType maskType = getMaskType();
4466 auto permutationMap = getPermutationMap();
4467 VectorType inferredMaskType =
4471 if (llvm::size(
getIndices()) != shapedType.getRank())
4472 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4476 if (hasBroadcastDim())
4477 return emitOpError(
"should not have broadcast dimensions");
4479 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4480 shapedType, vectorType, maskType,
4481 inferredMaskType, permutationMap, getInBounds())))
4485 [&](Twine t) {
return emitOpError(t); });
4492 Type TransferWriteOp::getExpectedMaskType() {
4513 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4517 if (write.getTransferRank() == 0)
4519 auto rankedTensorType =
4520 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4522 if (!rankedTensorType)
4525 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4529 if (read.getTransferRank() == 0)
4532 if (!read.getPermutationMap().isMinorIdentity() ||
4533 !write.getPermutationMap().isMinorIdentity())
4536 if (read.getTransferRank() != write.getTransferRank())
4539 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4542 if (read.getSource().getType() != rankedTensorType)
4545 if (read.getVectorType() != write.getVectorType())
4548 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4551 auto isNotConstantZero = [](
Value v) {
4553 return !cstOp.has_value() || cstOp.value() != 0;
4555 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4556 llvm::any_of(write.getIndices(), isNotConstantZero))
4559 results.push_back(read.getSource());
4563 static bool checkSameValueWAR(vector::TransferReadOp read,
4564 vector::TransferWriteOp write) {
4565 return read.getSource() == write.getSource() &&
4566 read.getIndices() == write.getIndices() &&
4567 read.getPermutationMap() == write.getPermutationMap() &&
4568 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4585 static LogicalResult foldWAR(TransferWriteOp write,
4587 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4589 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4593 if (!checkSameValueWAR(read, write))
4595 results.push_back(read.getSource());
4599 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4601 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
4603 if (succeeded(foldWAR(*
this, results)))
4612 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4616 void TransferWriteOp::getEffects(
4619 if (llvm::isa<MemRefType>(getShapedType()))
4654 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4656 vector::TransferWriteOp writeToModify = writeOp;
4659 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4663 writeToModify.getSourceMutable().assign(defWrite.getSource());
4668 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4669 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4673 if (!defWrite->hasOneUse())
4675 writeToModify = defWrite;
4676 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4705 struct SwapExtractSliceOfTransferWrite
4712 if (!insertOp.hasUnitStride())
4715 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4716 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4718 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4719 if (!transferOp || !transferOp->hasOneUse())
4724 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4726 "use-def chain is rank-reducing");
4730 if (!extractOp.hasZeroOffset()) {
4732 "ExtractSliceOp has non-zero offset");
4736 if (!llvm::all_of(transferOp.getIndices(), [](
Value value) {
4740 "TranferWriteOp has non-zero offset");
4744 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4746 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
4749 for (
auto [insertSize, extractSize] :
4750 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4753 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
4758 assert(transferOp.getVectorType().hasStaticShape() &&
4759 "expected vector to have a static shape");
4762 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4763 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
4765 insertOp,
"TransferWriteOp may not write the full tensor.");
4771 auto newExtractOp = rewriter.
create<tensor::ExtractSliceOp>(
4772 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4773 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4774 insertOp.getMixedStrides());
4775 auto newTransferWriteOp = rewriter.
create<TransferWriteOp>(
4776 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4777 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4780 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4790 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4797 static LogicalResult verifyLoadStoreMemRefLayout(
Operation *op,
4798 MemRefType memRefTy) {
4800 return op->
emitOpError(
"most minor memref dim must have unit stride");
4808 if (failed(verifyLoadStoreMemRefLayout(*
this, memRefTy)))
4812 Type memElemTy = memRefTy.getElementType();
4813 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4814 if (memVecTy != resVecTy)
4815 return emitOpError(
"base memref and result vector types should match");
4816 memElemTy = memVecTy.getElementType();
4819 if (resVecTy.getElementType() != memElemTy)
4820 return emitOpError(
"base and result element types should match");
4821 if (llvm::size(
getIndices()) != memRefTy.getRank())
4822 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
4840 if (failed(verifyLoadStoreMemRefLayout(*
this, memRefTy)))
4844 Type memElemTy = memRefTy.getElementType();
4845 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4846 if (memVecTy != valueVecTy)
4848 "base memref and valueToStore vector types should match");
4849 memElemTy = memVecTy.getElementType();
4852 if (valueVecTy.getElementType() != memElemTy)
4853 return emitOpError(
"base and valueToStore element type should match");
4854 if (llvm::size(
getIndices()) != memRefTy.getRank())
4855 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
4859 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
4869 VectorType maskVType = getMaskVectorType();
4870 VectorType passVType = getPassThruVectorType();
4874 if (resVType.getElementType() != memType.getElementType())
4875 return emitOpError(
"base and result element type should match");
4876 if (llvm::size(
getIndices()) != memType.getRank())
4877 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4878 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4879 return emitOpError(
"expected result dim to match mask dim");
4880 if (resVType != passVType)
4881 return emitOpError(
"expected pass_thru of same type as result type");
4894 load, load.getType(), load.getBase(), load.getIndices());
4897 rewriter.
replaceOp(load, load.getPassThru());
4902 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
4909 results.
add<MaskedLoadFolder>(context);
4923 VectorType maskVType = getMaskVectorType();
4927 if (valueVType.getElementType() != memType.getElementType())
4928 return emitOpError(
"base and valueToStore element type should match");
4929 if (llvm::size(
getIndices()) != memType.getRank())
4930 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4931 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4932 return emitOpError(
"expected valueToStore dim to match mask dim");
4945 store, store.getValueToStore(), store.getBase(), store.getIndices());
4953 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
4960 results.
add<MaskedStoreFolder>(context);
4963 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
4973 VectorType indVType = getIndexVectorType();
4974 VectorType maskVType = getMaskVectorType();
4976 ShapedType baseType = getBaseType();
4978 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4979 return emitOpError(
"requires base to be a memref or ranked tensor type");
4981 if (resVType.getElementType() != baseType.getElementType())
4982 return emitOpError(
"base and result element type should match");
4983 if (llvm::size(
getIndices()) != baseType.getRank())
4984 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
4985 if (resVType.getShape() != indVType.getShape())
4986 return emitOpError(
"expected result dim to match indices dim");
4987 if (resVType.getShape() != maskVType.getShape())
4988 return emitOpError(
"expected result dim to match mask dim");
4989 if (resVType != getPassThruVectorType())
4990 return emitOpError(
"expected pass_thru of same type as result type");
4998 Type GatherOp::getExpectedMaskType() {
4999 auto vecType = this->getIndexVectorType();
5002 vecType.getScalableDims());
5005 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5019 rewriter.
replaceOp(gather, gather.getPassThru());
5024 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
5031 results.
add<GatherFolder>(context);
5039 VectorType indVType = getIndexVectorType();
5040 VectorType maskVType = getMaskVectorType();
5044 if (valueVType.getElementType() != memType.getElementType())
5045 return emitOpError(
"base and valueToStore element type should match");
5046 if (llvm::size(
getIndices()) != memType.getRank())
5047 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5048 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5049 return emitOpError(
"expected valueToStore dim to match indices dim");
5050 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5051 return emitOpError(
"expected valueToStore dim to match mask dim");
5070 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
5077 results.
add<ScatterFolder>(context);
5085 VectorType maskVType = getMaskVectorType();
5086 VectorType passVType = getPassThruVectorType();
5090 if (resVType.getElementType() != memType.getElementType())
5091 return emitOpError(
"base and result element type should match");
5092 if (llvm::size(
getIndices()) != memType.getRank())
5093 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5094 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5095 return emitOpError(
"expected result dim to match mask dim");
5096 if (resVType != passVType)
5097 return emitOpError(
"expected pass_thru of same type as result type");
5110 expand, expand.getType(), expand.getBase(), expand.getIndices());
5113 rewriter.
replaceOp(expand, expand.getPassThru());
5118 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
5125 results.
add<ExpandLoadFolder>(context);
5133 VectorType maskVType = getMaskVectorType();
5137 if (valueVType.getElementType() != memType.getElementType())
5138 return emitOpError(
"base and valueToStore element type should match");
5139 if (llvm::size(
getIndices()) != memType.getRank())
5140 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5141 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5142 return emitOpError(
"expected valueToStore dim to match mask dim");
5147 class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
5155 compress, compress.getValueToStore(), compress.getBase(),
5156 compress.getIndices());
5164 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
5171 results.
add<CompressStoreFolder>(context);
5181 unsigned rankA = a.size();
5182 unsigned rankB = b.size();
5183 assert(rankA < rankB);
5185 auto isOne = [](int64_t v) {
return v == 1; };
5189 if (rankA == 0 && llvm::all_of(b, isOne))
5194 while (i < rankA &&
j < rankB) {
5195 int64_t dimA = a[i];
5197 while (dimB < dimA &&
j < rankB)
5205 if (i < rankA && llvm::all_of(a.slice(i), isOne))
5207 if (
j < rankB && llvm::all_of(b.slice(
j), isOne))
5211 return i == rankA &&
j == rankB;
5214 static LogicalResult verifyVectorShapeCast(
Operation *op,
5215 VectorType sourceVectorType,
5216 VectorType resultVectorType) {
5218 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5219 return op->
emitOpError(
"source/result vectors must have same element type");
5220 auto sourceShape = sourceVectorType.getShape();
5221 auto resultShape = resultVectorType.getShape();
5224 int64_t sourceDimProduct = std::accumulate(
5225 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5226 int64_t resultDimProduct = std::accumulate(
5227 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5228 if (sourceDimProduct != resultDimProduct)
5229 return op->
emitOpError(
"source/result number of elements must match");
5232 unsigned sourceRank = sourceVectorType.getRank();
5233 unsigned resultRank = resultVectorType.getRank();
5234 if (sourceRank < resultRank) {
5235 if (!isValidShapeCast(sourceShape, resultShape))
5237 }
else if (sourceRank > resultRank) {
5238 if (!isValidShapeCast(resultShape, sourceShape))
5243 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5244 int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5245 if (sourceNScalableDims != resultNScalableDims)
5246 return op->
emitOpError(
"different number of scalable dims at source (")
5247 << sourceNScalableDims <<
") and result (" << resultNScalableDims
5249 sourceVectorType.getNumDynamicDims();
5255 auto sourceVectorType =
5256 llvm::dyn_cast_or_null<VectorType>(getSource().
getType());
5257 auto resultVectorType =
5258 llvm::dyn_cast_or_null<VectorType>(getResult().
getType());
5261 if (sourceVectorType && resultVectorType)
5262 return verifyVectorShapeCast(*
this, sourceVectorType, resultVectorType);
5273 if (
auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5274 if (getResult().
getType() == otherOp.getSource().getType())
5275 return otherOp.getSource();
5278 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5279 VectorType resultType = llvm::cast<VectorType>(getResult().
getType());
5280 if (srcType.getRank() < resultType.getRank()) {
5281 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5283 }
else if (srcType.getRank() > resultType.getRank()) {
5284 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5290 setOperand(otherOp.getSource());
5295 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5296 if (bcastOp.getSourceType() ==
getType())
5297 return bcastOp.getSource();
5305 class ShapeCastConstantFolder final :
public OpRewritePattern<ShapeCastOp> {
5312 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5316 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5332 static VectorType trimTrailingOneDims(VectorType oldType) {
5339 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5340 newShape = newShape.drop_back(1);
5341 newScalableDims = newScalableDims.drop_back(1);
5346 if (newShape.empty()) {
5347 newShape = oldShape.take_back();
5348 newScalableDims = oldScalableDims.take_back();
5351 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5366 class ShapeCastCreateMaskFolderTrailingOneDim final
5373 Value shapeOpSrc = shapeOp->getOperand(0);
5374 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
5375 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
5376 if (!createMaskOp && !constantMaskOp)
5379 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5380 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5382 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5383 if (newVecType != shapeOpResTy)
5386 auto numDimsToDrop =
5387 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5394 auto maskOperands = createMaskOp.getOperands();
5395 auto numMaskOperands = maskOperands.size();
5398 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5400 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5401 if (!constant || (constant.value() != 1))
5405 maskOperands.drop_back(numDimsToDrop);
5412 if (constantMaskOp) {
5413 auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
5414 auto numMaskOperands = maskDimSizes.size();
5417 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5419 if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
5423 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5424 ArrayAttr newMaskOperandsAttr = rewriter.
getArrayAttr(newMaskOperands);
5427 newMaskOperandsAttr);
5440 class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
5447 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5452 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5453 broadcastSourceShape = srcType.getShape();
5455 shapeCastOp.getResultVectorType().getShape();
5459 if (broadcastSourceShape ==
5460 shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5462 shapeCastOp, shapeCastOp.getResultVectorType(),
5463 broadcastOp.getSource());
5469 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5470 if (srcType.getNumElements() ==
5471 shapeCastOp.getResultVectorType().getNumElements()) {
5473 shapeCastOp, shapeCastOp.getResultVectorType(),
5474 broadcastOp.getSource());
5487 results.
add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5488 ShapeCastBroadcastFolder>(context);
5496 auto sourceVectorType = getSourceVectorType();
5497 auto resultVectorType = getResultVectorType();
5499 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5500 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5501 return emitOpError(
"dimension size mismatch at: ") << i;
5504 DataLayout dataLayout = DataLayout::closest(*
this);
5505 auto sourceElementBits =
5507 auto resultElementBits =
5510 if (sourceVectorType.getRank() == 0) {
5511 if (sourceElementBits != resultElementBits)
5512 return emitOpError(
"source/result bitwidth of the 0-D vector element "
5513 "types must be equal");
5514 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
5515 resultElementBits * resultVectorType.getShape().back()) {
5517 "source/result bitwidth of the minor 1-D vectors must be equal");
5529 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5530 if (getResult().
getType() == otherOp.getSource().getType())
5531 return otherOp.getSource();
5533 setOperand(otherOp.getSource());
5537 Attribute sourceConstant = adaptor.getSource();
5538 if (!sourceConstant)
5541 Type srcElemType = getSourceVectorType().getElementType();
5542 Type dstElemType = getResultVectorType().getElementType();
5544 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5545 if (floatPack.isSplat()) {
5546 auto splat = floatPack.getSplatValue<FloatAttr>();
5549 if (srcElemType.
isF16() && dstElemType.
isF32()) {
5550 uint32_t bits =
static_cast<uint32_t
>(
5551 splat.getValue().bitcastToAPInt().getZExtValue());
5553 bits = (bits << 16) | (bits & 0xffff);
5554 APInt intBits(32, bits);
5555 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5561 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5562 if (intPack.isSplat()) {
5563 auto splat = intPack.getSplatValue<IntegerAttr>();
5565 if (llvm::isa<IntegerType>(dstElemType)) {
5570 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5571 APInt intBits = splat.getValue().zext(dstBitWidth);
5574 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5575 intBits = (intBits << srcBitWidth) | intBits;
5590 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5592 memRefType.getShape().end());
5594 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5603 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
5604 VectorType vectorType =
5608 memRefType.getMemorySpace()));
5613 if (!canonicalType.getLayout().isIdentity())
5614 return emitOpError(
"expects operand to be a memref with identity layout");
5615 if (!getResultMemRefType().getLayout().isIdentity())
5616 return emitOpError(
"expects result to be a memref with identity layout");
5617 if (getResultMemRefType().getMemorySpace() !=
5619 return emitOpError(
"expects result in same memory space");
5622 auto resultType = getResultMemRefType();
5626 "expects result and operand with same underlying scalar type: ")
5628 if (extractShape(sourceType) != extractShape(resultType))
5630 "expects concatenated result and operand shapes to be equal: ")
5641 VectorType vt = llvm::cast<VectorType>(vector.
getType());
5644 for (
unsigned i = 0; i < permutation.size(); ++i) {
5645 transposedShape[i] = vt.getShape()[permutation[i]];
5646 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5651 transposedScalableDims));
5656 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5659 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5661 return attr.reshape(getResultVectorType());
5669 for (int64_t i = 0, e = perm.size(); i < e; i++) {
5678 VectorType vectorType = getSourceVectorType();
5679 VectorType resultType = getResultVectorType();
5680 int64_t rank = resultType.getRank();
5681 if (vectorType.getRank() != rank)
5682 return emitOpError(
"vector result rank mismatch: ") << rank;
5685 int64_t size = perm.size();
5687 return emitOpError(
"transposition length mismatch: ") << size;
5690 if (ta.value() < 0 || ta.value() >= rank)
5691 return emitOpError(
"transposition index out of range: ") << ta.value();
5692 if (seen[ta.value()])
5693 return emitOpError(
"duplicate position index: ") << ta.value();
5694 seen[ta.value()] =
true;
5695 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5696 return emitOpError(
"dimension size mismatch at: ") << ta.value();
5701 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5702 return llvm::to_vector<4>(getResultVectorType().
getShape());
5708 class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
5718 for (
auto index : permutation2)
5719 result.push_back(permutation1[index]);
5724 vector::TransposeOp parentTransposeOp =
5725 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5726 if (!parentTransposeOp)
5730 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5733 transposeOp, transposeOp.getResult().getType(),
5734 parentTransposeOp.getVector(), permutation);
5740 struct FoldTransposedScalarBroadcast final
5746 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5750 auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5751 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5753 transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5768 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5773 transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5779 class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
5785 Value transposeSrc = transpOp.getVector();
5786 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
5787 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
5788 if (!createMaskOp && !constantMaskOp)
5796 auto maskOperands = createMaskOp.getOperands();
5801 transpOp, transpOp.getResultVectorType(), newOperands);
5806 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5811 transpOp, transpOp.getResultVectorType(),
5819 void vector::TransposeOp::getCanonicalizationPatterns(
5821 results.
add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5822 TransposeFolder, FoldTransposeSplat>(context);
5830 auto resultType = llvm::cast<VectorType>(getResult().
getType());
5832 if (resultType.getRank() == 0) {
5833 if (getMaskDimSizes().size() != 1)
5834 return emitError(
"array attr must have length 1 for 0-D vectors");
5835 auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
5836 if (dim != 0 && dim != 1)
5837 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
5842 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
5844 "must specify array attr of size equal vector result rank");
5847 auto resultShape = resultType.getShape();
5848 auto resultScalableDims = resultType.getScalableDims();
5850 for (
const auto [index, intAttr] :
llvm::enumerate(getMaskDimSizes())) {
5851 int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5852 if (maskDimSize < 0 || maskDimSize > resultShape[index])
5854 "array attr of size out of bounds of vector result dimension size");
5855 if (resultScalableDims[index] && maskDimSize != 0 &&
5856 maskDimSize != resultShape[index])
5858 "only supports 'none set' or 'all set' scalable dimensions");
5859 maskDimSizes.push_back(maskDimSize);
5863 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5864 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
5865 if (anyZeros && !allZeros)
5866 return emitOpError(
"expected all mask dim sizes to be zeros, "
5867 "as a result of conjunction with zero mask dim");
5871 bool ConstantMaskOp::isAllOnesMask() {
5874 if (resultType.getRank() == 0) {
5875 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
5876 return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
5878 for (
const auto [resultSize, intAttr] :
5879 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5880 int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5881 if (maskDimSize < resultSize)
5896 build(builder, result, type, operands);
5900 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
5902 if (vectorType.getRank() == 0) {
5903 if (getNumOperands() != 1)
5905 "must specify exactly one operand for 0-D create_mask");
5906 }
else if (getNumOperands() !=
5907 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
5909 "must specify an operand for each result vector dimension");
5945 VectorType retTy = createMaskOp.getResult().getType();
5946 bool isScalable = retTy.isScalable();
5949 for (
auto [opIdx, operand] :
llvm::enumerate(createMaskOp.getOperands())) {
5954 if (retTy.getScalableDims()[opIdx] && *cst > 0)
5969 auto mulLHS = mul.getRhs();
5970 auto mulRHS = mul.getLhs();
5971 bool isOneOpVscale =
5972 (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
5973 isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
5975 auto isConstantValMatchingDim =
5976 [=, dim = retTy.getShape()[opIdx]](
Value operand) {
5978 return (constantVal.has_value() && constantVal.value() == dim);
5981 bool isOneOpConstantMatchingDim =
5982 isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
5984 if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5990 maskDimSizes.reserve(createMaskOp->getNumOperands());
5991 for (
auto [operand, maxDimSize] : llvm::zip_equal(
5992 createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5997 maskDimSizes.push_back(maxDimSize);
6000 int64_t dimSizeVal =
std::min(dimSize.value(), maxDimSize);
6003 maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
6006 maskDimSizes.push_back(dimSizeVal);
6011 createMaskOp, retTy,
6021 results.
add<CreateMaskFolder>(context);
6032 assert(maskRegionBuilder &&
6033 "builder callback for 'maskRegion' must be present");
6039 maskRegionBuilder(builder, maskableOp);
6046 build(builder, result, resultTypes, mask,
Value(), maskableOp,
6054 build(builder, result, mask, maskableOp, maskRegionBuilder);
6075 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
6082 MaskOp::ensureTerminator(maskRegion, builder, result.
location);
6096 result.
types.append(resultTypes);
6102 if (parsePassthru.succeeded())
6110 p <<
" " << getMask();
6112 p <<
", " << getPassthru();
6116 Block *singleBlock = &getMaskRegion().getBlocks().
front();
6123 p <<
" : " << getMask().getType();
6124 if (getNumResults() > 0)
6125 p <<
" -> " << getResultTypes();
6130 MaskOp>::ensureTerminator(region, builder, loc);
6142 assert(isa<vector::YieldOp>(oldYieldOp) &&
"Expected vector::YieldOp");
6145 if (maskedOp == oldYieldOp)
6148 opBuilder.setInsertionPoint(oldYieldOp);
6149 opBuilder.create<vector::YieldOp>(loc, maskedOp->
getResults());
6151 oldYieldOp->
erase();
6156 Block &block = getMaskRegion().getBlocks().
front();
6158 return emitOpError(
"expects a terminator within the mask region");
6160 return emitOpError(
"expects only one operation to mask");
6163 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
6165 return emitOpError(
"expects a terminator within the mask region");
6167 if (terminator->getNumOperands() != getNumResults())
6169 "expects number of results to match mask region yielded values");
6171 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
6178 return emitOpError(
"expects number of results to match maskable operation "
6179 "number of results");
6181 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
6183 "expects result type to match maskable operation result type");
6186 [](
Type t) { return llvm::isa<VectorType>(t); }) > 1)
6187 return emitOpError(
"multiple vector results not supported");
6190 Type expectedMaskType = maskableOp.getExpectedMaskType();
6191 if (getMask().
getType() != expectedMaskType)
6192 return emitOpError(
"expects a ")
6193 << expectedMaskType <<
" mask for the maskable operation";
6196 Value passthru = getPassthru();
6198 if (!maskableOp.supportsPassthru())
6200 "doesn't expect a passthru argument for this maskable operation");
6203 return emitOpError(
"expects result when passthru argument is provided");
6206 return emitOpError(
"expects passthru type to match result type");
6213 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6223 Operation *maskableOp = getMaskableOp();
6227 llvm::append_range(results, maskableOp->
getResults());
6239 auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6240 if (maskingOp.getMaskableOp())
6243 if (!maskOp.isEmpty())
6246 Block *block = maskOp.getMaskBlock();
6247 auto terminator = cast<vector::YieldOp>(block->
front());
6248 if (terminator.getNumOperands() == 0)
6251 rewriter.
replaceOp(maskOp, terminator.getOperands());
6259 results.
add<ElideEmptyMaskOp>(context);
6266 Block *block = getMaskBlock();
6270 return &block->
front();
6274 bool MaskOp::hasPassthru() {
return getPassthru() !=
Value(); }
6281 VectorType srcType = getSourceType();
6282 VectorType initialType = getInitialValueType();
6284 int64_t srcRank = srcType.getRank();
6285 int64_t reductionDim = getReductionDim();
6286 if (reductionDim >= srcRank)
6287 return emitOpError(
"reduction dimension ")
6288 << reductionDim <<
" has to be less than " << srcRank;
6291 int64_t initialValueRank = initialType.getRank();
6292 if (initialValueRank != srcRank - 1)
6293 return emitOpError(
"initial value rank ")
6294 << initialValueRank <<
" has to be equal to " << srcRank - 1;
6300 for (
int i = 0; i < srcRank; i++) {
6301 if (i != reductionDim)
6302 expectedShape.push_back(srcShape[i]);
6304 if (!llvm::equal(initialValueShapes, expectedShape)) {
6305 return emitOpError(
"incompatible input/initial value shapes");
6309 Type eltType = getDestType().getElementType();
6311 return emitOpError(
"unsupported reduction type ")
6312 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
6321 .
add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6322 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6323 StridedSliceConstantMaskFolder, TransposeFolder>(
6332 auto constOperand = adaptor.getInput();
6333 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6345 auto resultType = cast<VectorType>(
getType());
6346 if (resultType.isScalable())
6349 for (
unsigned i = 0; i < resultType.getNumElements(); i++)
6350 indices.push_back(APInt(64, i));
6359 p <<
"(" << getLaneid() <<
")";
6362 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
6363 p <<
"[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() <<
"]";
6365 if (!getArgs().empty())
6366 p <<
" args(" << getArgs() <<
" : " << getArgs().getTypes() <<
")";
6367 if (!getResults().empty())
6368 p <<
" -> (" << getResults().getTypes() <<
')';
6372 !getResults().empty());
6402 llvm::SMLoc inputsOperandsLoc;
6414 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
6425 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.
location);
6433 void WarpExecuteOnLane0Op::getSuccessorRegions(
6447 build(builder, result, resultTypes, laneId, warpSize,
6448 std::nullopt, std::nullopt);
6460 assert(args.size() == blockArgTypes.size());
6464 for (
auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
6470 static LogicalResult verifyDistributedType(
Type expanded,
Type distributed,
6473 if (expanded == distributed)
6475 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
6476 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
6477 if (!expandedVecType || !distributedVecType)
6478 return op->
emitOpError(
"expected vector type for distributed operands.");
6479 if (expandedVecType.getRank() != distributedVecType.getRank() ||
6480 expandedVecType.getElementType() != distributedVecType.getElementType())
6482 "expected distributed vectors to have same rank and element type.");
6485 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
6486 int64_t eDim = expandedVecType.getDimSize(i);
6487 int64_t dDim = distributedVecType.getDimSize(i);
6490 if (eDim % dDim != 0)
6492 <<
"expected expanded vector dimension #" << i <<
" (" << eDim
6493 <<
") to be a multipler of the distributed vector dimension ("
6495 scales[i] = eDim / dDim;
6497 if (std::accumulate(scales.begin(), scales.end(), 1,
6498 std::multiplies<int64_t>()) != warpSize)
6500 <<
"incompatible distribution dimensions from " << expandedVecType
6501 <<
" to " << distributedVecType <<
" with warp size = " << warpSize;
6507 if (getArgs().size() != getWarpRegion().getNumArguments())
6509 "expected same number op arguments and block arguments.");
6511 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
6512 if (yield.getNumOperands() != getNumResults())
6514 "expected same number of yield operands and return values.");
6515 int64_t warpSize = getWarpSize();
6516 for (
auto [regionArg, arg] :
6517 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
6518 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
6519 warpSize, getOperation())))
6522 for (
auto [yieldOperand, result] :
6523 llvm::zip_equal(yield.getOperands(), getResults())) {
6524 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
6525 warpSize, getOperation())))
6531 bool WarpExecuteOnLane0Op::areTypesCompatible(
Type lhs,
Type rhs) {
6533 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
6538 arith::FastMathFlagsAttr fastmath,
6545 case CombiningKind::ADD:
6548 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6549 result = b.
createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6551 llvm_unreachable(
"invalid value types for ADD reduction");
6553 case CombiningKind::AND:
6557 case CombiningKind::MAXNUMF:
6558 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6559 "expected float values");
6560 result = b.
createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6562 case CombiningKind::MAXIMUMF:
6563 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6564 "expected float values");
6565 result = b.
createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6567 case CombiningKind::MINNUMF:
6568 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6569 "expected float values");
6570 result = b.
createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6572 case CombiningKind::MINIMUMF:
6573 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6574 "expected float values");
6575 result = b.
createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6577 case CombiningKind::MAXSI:
6581 case CombiningKind::MINSI:
6585 case CombiningKind::MAXUI:
6593 case CombiningKind::MUL:
6596 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6597 result = b.
createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6599 llvm_unreachable(
"invalid value types for MUL reduction");
6601 case CombiningKind::OR:
6605 case CombiningKind::XOR:
6611 assert(result &&
"unknown CombiningKind");
6623 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
6643 return builder.
create<MaskOp>(maskableOp->getLoc(),
6644 maskableOp->getResultTypes(), mask, maskableOp,
6661 mask, newValue, passthru);
6668 #define GET_ATTRDEF_CLASSES
6669 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6671 #define GET_OP_CLASSES
6672 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static VectorShape vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.