42#include "llvm/ADT/ArrayRef.h"
43#include "llvm/ADT/Repeated.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/SmallVector.h"
46#include "llvm/ADT/SmallVectorExtras.h"
47#include "llvm/ADT/StringSet.h"
48#include "llvm/ADT/TypeSwitch.h"
49#include "llvm/Support/Casting.h"
55#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
57#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
// NOTE(review): excerpt with elided lines. The fragments below belong to
// static helpers that analyze mask-producing ops: a dense i1 constant is
// scanned element-by-element, vector.create_mask operands are compared
// against the mask type's shape, and constant operand values are pulled out
// of defining arith.constant ops. The last line builds an implicit
// vector.yield terminator for a region body.
78 if (
// Only dense integer-element constants are analyzable here.
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
80 for (
bool b : denseElts.getValues<
bool>())
83 else if (!
b && val <= 0)
// Compare each mask operand against the corresponding result dimension.
97 auto shape = m.getType().getShape();
100 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks,
shape)) {
101 if (maskIdx < dimSize)
// Collect constant create_mask operands (non-constants are handled in
// elided lines).
114 auto maskOperands = m.getOperands();
115 for (
Value operand : maskOperands) {
116 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
118 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
// Region builder: terminate with an empty vector.yield.
131 vector::YieldOp::create(builder, loc);
// Validates a CombiningKind against an element type. The switch groups
// kinds by the element types they support; the return statements for the
// ADD/MUL group and the integer-only group (MINUI..XOR) are elided from
// this excerpt. The float-only kinds fall through to the visible return.
137 switch (combiningKind) {
138 case CombiningKind::ADD:
139 case CombiningKind::MUL:
141 case CombiningKind::MINUI:
142 case CombiningKind::MINSI:
143 case CombiningKind::MAXUI:
144 case CombiningKind::MAXSI:
145 case CombiningKind::AND:
146 case CombiningKind::OR:
147 case CombiningKind::XOR:
149 case CombiningKind::MINNUMF:
150 case CombiningKind::MAXNUMF:
151 case CombiningKind::MINIMUMF:
152 case CombiningKind::MAXIMUMF:
// Floating-point min/max kinds require a FloatType element.
153 return llvm::isa<FloatType>(elementType);
// Fragment (signature start elided): computes the "effective" rank of a
// transfer by subtracting the rank of a vector element type — relevant when
// the shaped type is a tensor/memref *of vectors* — from the transfer
// vector's rank.
183 VectorType vectorType) {
184 unsigned elementVectorRank = 0;
185 VectorType elementVectorType =
186 llvm::dyn_cast<VectorType>(shapedType.getElementType());
187 if (elementVectorType)
188 elementVectorRank += elementVectorType.getRank();
189 return vectorType.getRank() - elementVectorRank;
// Fragment of a follow-on helper (body largely elided) that builds a
// permutation map from the shaped type's rank and context; the rank-0
// shaped-type case is special-cased. TODO(review): confirm against full file.
193 VectorType vectorType) {
196 if (shapedType.getRank() == 0 &&
202 shapedType.getRank(),
204 shapedType.getContext());
// Fragments of store-to-load forwarding helpers for vector transfer ops.
// First: a splat-consistency check — a read mask may allow forwarding only
// when the write mask is absent or identical. Then RAW: a transfer_read can
// reuse a transfer_write's value when indices, vector type, permutation map
// and masks all line up and the write is fully in-bounds. Then WAW: a prior
// write is dead when a later write covers the same indices/type/map/mask.
211 vector::TransferReadOp read) {
212 auto readMask = read.getMask();
213 auto writeMask = write.getMask();
// A masked write forwarded to a masked read must use the same mask value.
219 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
220 if (!couldBeSameSplat)
// RAW check (signature start and mask-comparison tail elided).
237 vector::TransferReadOp read) {
238 return !defWrite.hasOutOfBoundsDim() &&
239 defWrite.getIndices() == read.getIndices() &&
240 defWrite.getVectorType() == read.getVectorType() &&
241 defWrite.getPermutationMap() == read.getPermutationMap() &&
242 ((!defWrite.getMask() && !read.getMask()) ||
// WAW check: structural equality of the two writes' access descriptors.
247 vector::TransferWriteOp priorWrite) {
248 return priorWrite.getIndices() == write.getIndices() &&
249 priorWrite.getMask() == write.getMask() &&
250 priorWrite.getVectorType() == write.getVectorType() &&
251 priorWrite.getPermutationMap() == write.getPermutationMap();
// isDisjointTransferIndices: decides whether two transfers on the same
// underlying buffer access provably disjoint index sets. Leading (not
// transferred) dims are disjoint iff the indices differ; transferred dims
// are disjoint iff the index distance is >= the vector dim size. When
// testDynamicValueUsingBounds is set, non-constant indices are compared via
// value-bounds analysis (the helper calls themselves are elided here).
255 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
256 bool testDynamicValueUsingBounds) {
258 if (transferA.getVectorType() != transferB.getVectorType())
260 unsigned rankOffset = transferA.getLeadingShapedRank();
261 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
262 Value indexA = transferA.getIndices()[i];
263 Value indexB = transferB.getIndices()[i];
// Leading dims: any provable difference makes the accesses disjoint.
267 if (i < rankOffset) {
270 if (cstIndexA.has_value() && cstIndexB.has_value()) {
271 if (*cstIndexA != *cstIndexB)
275 if (testDynamicValueUsingBounds) {
278 FailureOr<uint64_t> delta =
280 if (succeeded(delta) && *delta != 0)
283 FailureOr<bool> testEqual =
285 if (succeeded(testEqual) && !testEqual.value())
// Transferred dims: disjoint when |indexA - indexB| >= vector dim size.
291 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
292 if (cstIndexA.has_value() && cstIndexB.has_value()) {
293 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
294 if (distance >= vectorDim)
298 if (testDynamicValueUsingBounds) {
301 FailureOr<int64_t> delta =
303 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
306 FailureOr<int64_t> computeDelta =
308 if (succeeded(computeDelta)) {
309 if (std::abs(computeDelta.value()) >= vectorDim)
// isDisjointTransferSet: same-base precondition, then delegates to the
// index-disjointness check above.
319 VectorTransferOpInterface transferB,
320 bool testDynamicValueUsingBounds) {
321 if (transferA.getBase() != transferB.getBase())
324 testDynamicValueUsingBounds);
// Fragments of several small static utilities (many lines elided):
//  * odometer-style increment of a slice position, carrying from the
//    innermost dimension outward;
//  * conversion of constant index Values / OpFoldResults to int64_t vectors
//    (asserting everything is constant);
//  * materialization of OpFoldResults back into Values;
//  * detection of vector.vscale operands on either side of an expression;
//  * normalization of Integer/Float attributes to an expected integer type
//    (float payloads are bitcast, not value-converted).
334 for (
auto [posInDim, dimSize, offsetInDim] :
335 llvm::reverse(llvm::zip_equal(position,
shape, offsets))) {
337 if (posInDim < dimSize + offsetInDim)
// Carry: reset this dimension to its origin and continue outward.
341 posInDim = offsetInDim;
351 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
353 assert(constOp &&
"Unexpected non-constant index");
354 return constOp.value();
364 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
365 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
366 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
376 llvm::transform(foldResults, std::back_inserter(values),
378 if (
auto attr = dyn_cast<Attribute>(foldResult))
380 builder, loc, cast<IntegerAttr>(attr).getInt())
383 return cast<Value>(foldResult);
396 if (
lhs.getDefiningOp<vector::VectorScaleOp>())
398 if (
rhs.getDefiningOp<vector::VectorScaleOp>())
408 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
409 if (
auto intType = dyn_cast<IntegerType>(expectedType)) {
410 if (intAttr.getType() != expectedType)
411 return IntegerAttr::get(expectedType, intAttr.getInt());
417 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
418 auto intType = dyn_cast<IntegerType>(expectedType);
// Reinterpret the float's bit pattern as an integer (no rounding).
422 APFloat floatVal = floatAttr.getValue();
423 APInt intVal = floatVal.bitcastToAPInt();
424 return IntegerAttr::get(expectedType, intVal);
// Dialect plumbing (fragmentary): an inliner interface for the vector
// dialect, dialect initialization that registers generated attributes/ops
// plus promised external-model interfaces (bufferization, subset ops,
// LLVM conversion), and constant materialization that defers to
// arith::ConstantOp.
463struct VectorInlinerInterface :
public DialectInlinerInterface {
464 using DialectInlinerInterface::DialectInlinerInterface;
473void VectorDialect::initialize() {
475#define GET_ATTRDEF_LIST
476#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
481#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
484 addInterfaces<VectorInlinerInterface>();
// Interfaces implemented in other libraries; declared here so loading the
// dialect alone does not pull in those dependencies.
486 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
487 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
489 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
491 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
492 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
493 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
// materializeConstant: vector constants are arith.constant ops.
504 return arith::ConstantOp::materialize(builder, value, type, loc);
// MultiDimReductionOp convenience builder: converts a boolean reduction
// mask into the explicit list of reduced dimension indices, then forwards
// to the primary builder.
520void vector::MultiDimReductionOp::build(
OpBuilder &builder,
523 CombiningKind kind) {
525 for (
const auto &en : llvm::enumerate(reductionMask))
527 reductionDims.push_back(en.index());
528 build(builder,
result, kind, source,
acc, reductionDims);
// fold: a 1-D reduction over a non-reduced dim is the identity (the actual
// folded value returned is on an elided line).
531OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
533 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
// Unrolling iterates over the full source vector shape.
538std::optional<SmallVector<int64_t, 4>>
539MultiDimReductionOp::getShapeForUnroll() {
540 return llvm::to_vector<4>(getSourceVectorType().
getShape());
// verify: recomputes the expected result type by dropping every reduced
// dimension from the source vector shape (keeping scalability flags for the
// surviving dims) and checks it against the declared result type. A fully
// reduced vector yields a scalar of the element type.
543LogicalResult MultiDimReductionOp::verify() {
546 Type inferredReturnType;
547 auto sourceScalableDims = getSourceVectorType().getScalableDims();
548 for (
auto [dimIdx, dimSize] :
549 llvm::enumerate(getSourceVectorType().
getShape()))
// Keep only dimensions that are NOT listed in the reduction dims.
550 if (!llvm::any_of(getReductionDims(),
551 [dimIdx = dimIdx](
int64_t reductionDimIdx) {
552 return reductionDimIdx ==
static_cast<int64_t>(dimIdx);
554 targetShape.push_back(dimSize);
555 scalableDims.push_back(sourceScalableDims[dimIdx]);
// Empty target shape => every dim was reduced => scalar result.
558 if (targetShape.empty())
559 inferredReturnType = getSourceVectorType().getElementType();
561 inferredReturnType = VectorType::get(
562 targetShape, getSourceVectorType().
getElementType(), scalableDims);
563 if (
getType() != inferredReturnType)
565 <<
" is incompatible with source type "
566 << getSourceVectorType();
// getExpectedMaskType: the mask for a masked multi_reduction is an i1
// vector with the same shape (and scalability) as the *source* vector.
572Type MultiDimReductionOp::getExpectedMaskType() {
573 auto vecType = getSourceVectorType();
574 return VectorType::get(vecType.getShape(),
575 IntegerType::get(vecType.getContext(), 1),
576 vecType.getScalableDims());
// Canonicalization: when every reduced dimension has size 1, the reduction
// is a no-op and can be rewritten as a shape_cast (vector result) or
// extract (scalar result). Handles the masked form by shape-casting /
// extracting the mask alongside the source and replacing the enclosing
// vector.mask op instead of the reduction itself.
585struct ElideUnitDimsInMultiDimReduction
589 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
590 PatternRewriter &rewriter)
const override {
591 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
// Bail out unless every reduced dim is a unit dim.
592 for (
const auto &dim :
enumerate(shape)) {
593 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
598 OpBuilder::InsertionGuard guard(rewriter);
// For masked reductions, the op to replace is the surrounding mask op.
601 if (reductionOp.isMasked()) {
603 rootOp = reductionOp.getMaskingOp();
604 mask = reductionOp.getMaskingOp().getMask();
606 rootOp = reductionOp;
609 Location loc = reductionOp.getLoc();
610 Value acc = reductionOp.getAcc();
612 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
// Vector result: shape_cast source (and mask) to the destination shape.
614 VectorType newMaskType =
615 VectorType::get(dstVecType.getShape(), rewriter.
getI1Type(),
616 dstVecType.getScalableDims());
617 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
619 cast = vector::ShapeCastOp::create(
620 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
// Scalar result: extract the single element (and mask bit).
625 mask = vector::ExtractOp::create(rewriter, loc, mask);
626 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
631 cast,
nullptr, mask);
638void MultiDimReductionOp::getCanonicalizationPatterns(
640 results.
add<ElideUnitDimsInMultiDimReduction>(context);
// ReductionOp fragments: two builder overloads taking fast-math flags (the
// second derives the scalar result type from the input vector's element
// type), the verifier (rank limit and element-type/kind compatibility — the
// checks between the visible lines are elided), and getExpectedMaskType,
// which mirrors the source vector shape as an i1 vector.
649 arith::FastMathFlags fastMathFlags) {
655 arith::FastMathFlags fastMathFlags) {
657 llvm::cast<VectorType>(
vector.getType()).getElementType(), kind,
vector,
661LogicalResult ReductionOp::verify() {
// Only low-rank reductions are supported; exact bound is on an elided line.
663 int64_t rank = getSourceVectorType().getRank();
665 return emitOpError(
"unsupported reduction rank: ") << rank;
668 Type eltType = getDest().getType();
671 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
680Type ReductionOp::getExpectedMaskType() {
681 auto vecType = getSourceVectorType();
682 return VectorType::get(vecType.getShape(),
683 IntegerType::get(vecType.getContext(), 1),
684 vecType.getScalableDims());
// getVectorReductionOp (fragment of the enclosing switch): maps each
// arith::AtomicRMWKind to the corresponding vector::CombiningKind and
// builds a vector.reduction op at the operand's location. Float and integer
// variants of the same operation (addf/addi, mulf/muli) share a kind.
691 case arith::AtomicRMWKind::addf:
692 case arith::AtomicRMWKind::addi:
693 return vector::ReductionOp::create(builder,
vector.getLoc(),
694 CombiningKind::ADD,
vector);
695 case arith::AtomicRMWKind::mulf:
696 case arith::AtomicRMWKind::muli:
697 return vector::ReductionOp::create(builder,
vector.getLoc(),
698 CombiningKind::MUL,
vector);
699 case arith::AtomicRMWKind::minimumf:
700 return vector::ReductionOp::create(builder,
vector.getLoc(),
701 CombiningKind::MINIMUMF,
vector);
702 case arith::AtomicRMWKind::mins:
703 return vector::ReductionOp::create(builder,
vector.getLoc(),
704 CombiningKind::MINSI,
vector);
705 case arith::AtomicRMWKind::minu:
706 return vector::ReductionOp::create(builder,
vector.getLoc(),
707 CombiningKind::MINUI,
vector);
708 case arith::AtomicRMWKind::maximumf:
709 return vector::ReductionOp::create(builder,
vector.getLoc(),
710 CombiningKind::MAXIMUMF,
vector);
711 case arith::AtomicRMWKind::maxs:
712 return vector::ReductionOp::create(builder,
vector.getLoc(),
713 CombiningKind::MAXSI,
vector);
714 case arith::AtomicRMWKind::maxu:
715 return vector::ReductionOp::create(builder,
vector.getLoc(),
716 CombiningKind::MAXUI,
vector);
717 case arith::AtomicRMWKind::andi:
718 return vector::ReductionOp::create(builder,
vector.getLoc(),
719 CombiningKind::AND,
vector);
720 case arith::AtomicRMWKind::ori:
721 return vector::ReductionOp::create(builder,
vector.getLoc(),
722 CombiningKind::OR,
vector);
723 case arith::AtomicRMWKind::minnumf:
724 return vector::ReductionOp::create(builder,
vector.getLoc(),
725 CombiningKind::MINNUMF,
vector);
726 case arith::AtomicRMWKind::maxnumf:
727 return vector::ReductionOp::create(builder,
vector.getLoc(),
728 CombiningKind::MAXNUMF,
vector);
729 case arith::AtomicRMWKind::xori:
730 return vector::ReductionOp::create(builder,
vector.getLoc(),
731 CombiningKind::XOR,
vector);
// Unrolling iterates over the full source vector shape.
739std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
740 return llvm::to_vector<4>(getSourceVectorType().
getShape());
// Canonicalization (fragment): a reduction over a 0-D or single-element
// vector is just an extract of that element, combined with the accumulator
// (combining step is on elided lines). Masked reductions extract the mask
// bit too and replace the surrounding vector.mask op.
747 LogicalResult matchAndRewrite(ReductionOp reductionOp,
752 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
755 if (maskableOp.isMasked()) {
757 rootOp = maskableOp.getMaskingOp();
758 mask = maskableOp.getMaskingOp().getMask();
760 rootOp = reductionOp;
// Only fires for 0-D vectors or rank>=1 vectors whose leading dim is 1.
763 auto vectorType = reductionOp.getSourceVectorType();
764 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
767 Location loc = reductionOp.getLoc();
769 mask = ExtractOp::create(rewriter, loc, mask);
770 Value
result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
772 if (Value acc = reductionOp.getAcc())
775 reductionOp.getFastmathAttr(), mask);
785 results.
add<ElideSingleElementReduction>(context);
// ContractionOp builder fragments: one overload converts utils::IteratorType
// values into IteratorTypeAttrs and defaults the combining kind; the other
// takes pre-built ArrayAttrs and attaches indexing_maps, iterator_types and
// kind attributes directly.
799 getIndexingMapsAttrName(
result.name),
803 getIteratorTypesAttrName(
result.name),
806 return IteratorTypeAttr::get(builder.getContext(), t);
815 ContractionOp::getDefaultKind());
821 ArrayAttr iteratorTypes, CombiningKind kind) {
824 result.addAttribute(getIndexingMapsAttrName(
result.name), indexingMaps);
825 result.addAttribute(getIteratorTypesAttrName(
result.name), iteratorTypes);
827 CombiningKindAttr::get(builder.
getContext(), kind));
// ContractionOp::parse (fragment): the custom syntax carries the op's
// attributes in a leading dictionary. Iterator types are parsed as strings
// (legacy form) and converted to IteratorTypeAttrs; a missing kind
// attribute defaults to ContractionOp::getDefaultKind(); optional trailing
// masks must come as exactly two operands typed like LHS/RHS.
838 DictionaryAttr dictAttr;
852 result.attributes.append(dictAttr.getValue().begin(),
853 dictAttr.getValue().end());
859 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
860 result.attributes.get(getIteratorTypesAttrName(
result.name)));
861 if (!iteratorTypes) {
863 <<
"expected " << getIteratorTypesAttrName(
result.name)
864 <<
" array attribute";
// Convert string iterator types (printed form) into typed attributes.
869 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
870 auto maybeIteratorType = symbolizeIteratorType(s);
871 if (!maybeIteratorType.has_value())
872 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
874 iteratorTypeAttrs.push_back(
875 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
877 result.attributes.set(getIteratorTypesAttrName(
result.name),
// Backwards compatibility: default the combining kind when absent.
880 if (!
result.attributes.get(getKindAttrName(
result.name))) {
882 getKindAttrName(
result.name),
883 CombiningKindAttr::get(
result.getContext(),
884 ContractionOp::getDefaultKind()));
886 if (masksInfo.empty())
888 if (masksInfo.size() != 2)
890 "expected zero or exactly 2 vector mask operands");
891 auto lhsType = llvm::cast<VectorType>(types[0]);
892 auto rhsType = llvm::cast<VectorType>(types[1]);
894 std::array<VectorType, 2> maskTypes = {
// ContractionOp::print (fragment): mirrors parse — collects the trait
// attributes into a leading dictionary, converting iterator_types back to
// their string spelling for readability, then prints operands and types.
904 auto attrNames = getTraitAttrNames();
906 traitAttrsSet.insert_range(attrNames);
908 for (
auto attr : (*this)->getAttrs()) {
909 if (attr.getName() == getIteratorTypesAttrName()) {
911 llvm::cast<ArrayAttr>(attr.getValue())
912 .getAsValueRange<IteratorTypeAttr, IteratorType>();
// Print iterator types as plain strings (the form parse() accepts).
918 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) ->
Attribute {
919 return StringAttr::get(
getContext(), stringifyIteratorType(t));
922 attrs.emplace_back(getIteratorTypesAttrName(),
923 ArrayAttr::get(
getContext(), iteratorTypeNames));
924 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
925 attrs.push_back(attr);
928 auto dictAttr = DictionaryAttr::get(
getContext(), attrs);
929 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
930 p << getRhs() <<
", " << getAcc();
933 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
// verifyDimMap (fragment): each (lhsDim, rhsDim) pair must index a valid
// dimension of its respective type, and the paired dimensions must have
// equal sizes.
938 const std::vector<std::pair<int64_t, int64_t>> &map) {
939 for (
auto &dimPair : map) {
940 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
941 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
942 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
// verifyOutputShape (fragment): derives the expected accumulator/result
// shape of a contraction. Free (non-contracting, non-batch) LHS dims plus
// free RHS dims form the result; an empty result means scalar acc/result.
// A second, map-based check rebuilds the result shape through the indexing
// maps and compares it structurally.
949 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
951 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
952 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
// Index sets of contracting dims on each side.
955 for (
auto &dimPair : contractingDimMap) {
956 lhsContractingDimSet.insert(dimPair.first);
957 rhsContractingDimSet.insert(dimPair.second);
960 llvm::make_second_range(batchDimMap));
// LHS free dims appear in the result in order.
964 for (
int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
965 if (lhsContractingDimSet.count(i) > 0)
967 expectedResultDims.push_back(lhsType.getDimSize(i));
// RHS free dims follow (batch dims counted only once, on the LHS side).
971 for (
int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
972 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
974 expectedResultDims.push_back(rhsType.getDimSize(i));
978 if (expectedResultDims.empty()) {
// No free dims: accumulator and result must both be scalars.
980 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
981 return op.emitOpError(
"invalid accumulator/result vector shape");
984 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
985 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
986 if (!resVectorType || !accVectorType)
987 return op.emitOpError(
"invalid accumulator/result vector shape");
// Infer per-iteration-dim extents from the LHS/RHS maps and shapes.
993 AffineMap lhsMap = op.getIndexingMapsArray()[0];
994 AffineMap rhsMap = op.getIndexingMapsArray()[1];
996 return op.emitOpError(
997 "expected all dimensions to be either a LHS or a RHS dimension");
1000 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1001 VectorType v = pair.first;
1002 auto map = pair.second;
1003 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1004 unsigned pos = map.getDimPosition(idx);
1009 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
1010 return op.emitOpError(
"expected all dimensions to get an extent as "
1011 "either a LHS or a RHS dimension");
1013 AffineMap resMap = op.getIndexingMapsArray()[2];
1018 assert(llvm::all_of(expectedMap.
getResults(),
1019 llvm::IsaPred<AffineConstantExpr>) &&
1020 "expected constant extent along all dimensions.");
// Compose the result map with the extents to get the concrete shape.
1022 auto expectedShape =
1024 return cast<AffineConstantExpr>(e).getValue();
1027 VectorType::get(expectedShape, resVectorType.getElementType(),
1028 resVectorType.getScalableDims());
1029 if (resVectorType != expected || accVectorType != expected)
1030 return op.emitOpError(
1031 "invalid accumulator/result vector shape, expected: ")
// ContractionOp::verify (fragment): checks, in order — signless integer
// elements only; exactly three indexing maps; each map has no symbols,
// one input per iterator, one result per operand dim, and is a projected
// permutation; at least one contracting dim pair with consistent sizes;
// consistent batch dims; valid output shape; kind attribute present; and a
// combining kind compatible with the element type. Finally defers to the
// IndexingMapOpInterface verifier.
1037LogicalResult ContractionOp::verify() {
1038 VectorType lhsType = getLhsType();
1039 VectorType rhsType = getRhsType();
1040 Type accType = getAccType();
1041 Type resType = getResultType();
1043 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1044 if (!lhsType.getElementType().isSignlessInteger())
1045 return emitOpError(
"only supports signless integer types");
1049 if (getIndexingMapsArray().size() != 3)
1050 return emitOpError(
"expected an indexing map for each vector operand")
// Per-map structural checks against the iterator count and operand rank.
1055 unsigned numIterators = getIteratorTypes().getValue().size();
1056 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1057 auto index = it.index();
1058 auto map = it.value();
1059 if (map.getNumSymbols() != 0)
1061 <<
index <<
" to have no symbols";
1062 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(
index).
getType());
1063 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1066 if (map.getNumDims() != numIterators)
1068 <<
index <<
" to have " << numIterators <<
" number of inputs";
1069 if (map.getNumResults() != rank)
1071 <<
index <<
" to have " << rank <<
" number of outputs";
1072 if (!map.isProjectedPermutation())
1074 <<
index <<
" to be a projected permutation of its inputs";
1077 auto contractingDimMap = getContractingDimMap();
1078 auto batchDimMap = getBatchDimMap();
1081 if (contractingDimMap.empty())
1082 return emitOpError(
"expected at least one contracting dimension pair");
1085 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1086 return emitOpError(
"invalid contracting dimension map");
1090 return emitOpError(
"invalid batch dimension map");
1094 contractingDimMap, batchDimMap)))
1097 if (!getKindAttr()) {
1098 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
1099 "'vector.kind<add>')");
// The combining kind must support the (vector or scalar) element type.
1103 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1104 auto elementType = vectorType ? vectorType.getElementType() : resType;
1106 return emitOpError(
"unsupported contraction type");
1109 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
// getExpectedMaskType (fragment): the contraction mask has one dim per
// iteration dimension. Each mask dim's size and scalability is read off the
// LHS/RHS shapes via the respective indexing maps (assignment lines to the
// maskShape arrays are elided). The final fragment lists the trait
// attribute names (indexing_maps, iterator_types, kind).
1116Type ContractionOp::getExpectedMaskType() {
1117 auto indexingMaps = this->getIndexingMapsArray();
1120 VectorType lhsType = this->getLhsType();
1121 VectorType rhsType = this->getRhsType();
1123 unsigned numVecDims = lhsIdxMap.
getNumDims();
// Populate mask dims covered by LHS, then RHS (RHS fills remaining dims).
1129 for (
auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1132 lhsType.getScalableDims()[dimIdx];
1134 for (
auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1137 rhsType.getScalableDims()[dimIdx];
1140 assert(ShapedType::isStaticShape(maskShape) &&
1141 "Mask shape couldn't be computed");
1143 return VectorType::get(maskShape,
1144 IntegerType::get(lhsType.getContext(), 1),
1145 maskShapeScalableDims);
1150 getIteratorTypesAttrName(), getKindAttrName()};
// getDimMap (fragment): for every iterator of the target type, finds the
// LHS and RHS dimension positions mapped to that iterator (lookups elided)
// and records the pair when both sides participate.
1160static std::vector<std::pair<int64_t, int64_t>>
1162 IteratorType targetIteratorType,
MLIRContext *context) {
1163 std::vector<std::pair<int64_t, int64_t>> dimMap;
1164 for (
const auto &it : llvm::enumerate(iteratorTypes)) {
1165 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1166 if (iteratorType != targetIteratorType)
// A dim position of -1 means that side does not use this iterator.
1172 if (lhsDim >= 0 && rhsDim >= 0)
1173 dimMap.emplace_back(lhsDim, rhsDim);
// Iteration-space queries (fragments): getIterationBounds reads reduction
// bounds from the LHS shape and parallel bounds from the result shape;
// getIterationIndexMap inverts each indexing map into iteration-dim ->
// operand-dim tables; the dim-map accessors filter by iterator type; and
// getShapeForUnroll exposes the iteration bounds.
1178void ContractionOp::getIterationBounds(
1180 auto lhsShape = getLhsType().getShape();
1181 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1183 for (
const auto &it : llvm::enumerate(getIteratorTypes())) {
1186 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1187 if (iteratorType == IteratorType::reduction) {
// Reduction dims must appear in the LHS map; bound comes from LHS shape.
1190 assert(lhsDimIndex >= 0);
1191 iterationBounds.push_back(lhsShape[lhsDimIndex]);
// Parallel dims must appear in the result map; bound from result shape.
1196 assert(resDimIndex >= 0);
1197 assert(resVectorType !=
nullptr);
1198 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1202void ContractionOp::getIterationIndexMap(
1204 unsigned numMaps = getIndexingMapsArray().size();
1205 iterationIndexMap.resize(numMaps);
1206 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1207 auto index = it.index();
1208 auto map = it.value();
1209 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1210 auto dim = cast<AffineDimExpr>(map.getResult(i));
1211 iterationIndexMap[
index][dim.getPosition()] = i;
1216std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1218 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1222std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1224 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1228std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1230 getIterationBounds(
shape);
// Canonicalization (fragment, templated over the add op type so it covers
// both arith.addi and arith.addf): rewrites
//   add(contract(a, b, zero-acc), x)  ->  contract(a, b, x)
// by cloning the contraction with the add's other operand mapped in as the
// accumulator. Tries both operand orders of the add.
1252template <
typename AddOpType>
1258 auto canonicalize = [&](
Value maybeContraction,
1259 Value otherOperand) -> vector::ContractionOp {
1260 vector::ContractionOp contractionOp =
1261 dyn_cast_or_null<vector::ContractionOp>(
1264 return vector::ContractionOp();
// The fold is only valid when the existing accumulator is the zero value.
1265 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1266 contractionOp.getAcc().getDefiningOp())) {
1267 if (maybeZero.getValue() ==
1268 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
// Clone the contraction, substituting the add's operand for the acc.
1270 bvm.
map(contractionOp.getAcc(), otherOperand);
1271 auto newContraction =
1272 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1273 rewriter.
replaceOp(addOp, newContraction.getResult());
1274 return newContraction;
1277 return vector::ContractionOp();
1280 Value a = addOp->getOperand(0),
b = addOp->getOperand(1);
1281 vector::ContractionOp
contract = canonicalize(a,
b);
// ExtractOp fragments: range inference forwards the source range to the
// result; a builder fragment splits mixed positions into dynamic operands;
// inferReturnTypes yields the element type when the position is full-rank,
// otherwise a vector with the leading `n` dims (and scalability flags)
// dropped.
1306 setResultRanges(getResult(), argRanges.front());
1311 auto vectorTy = cast<VectorType>(source.
getType());
1336 build(builder,
result, source, dynamicPos,
1341ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1342 ExtractOp::Adaptor adaptor,
1344 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
// Full-rank position extracts a scalar element.
1345 if (
static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1346 vectorType.getRank()) {
1347 inferredReturnTypes.push_back(vectorType.getElementType());
// Partial position extracts a sub-vector: drop the leading indexed dims.
1349 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1350 vectorType.getRank());
1351 inferredReturnTypes.push_back(VectorType::get(
1352 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1353 vectorType.getScalableDims().drop_front(n)));
// ExtractOp::verify (fragment): forbids 0-D vector results, requires the
// kDynamic markers in the static position to match the number of dynamic
// position operands, bounds the position rank by the source rank, and
// checks every static index is in-range for its dimension (or the poison
// sentinel -1).
1358LogicalResult vector::ExtractOp::verify() {
1359 if (
auto resTy = dyn_cast<VectorType>(getResult().
getType()))
1360 if (resTy.getRank() == 0)
1362 "expected a scalar instead of a 0-d vector as the result type");
// Dynamic markers and dynamic operands must correspond one-to-one.
1365 auto dynamicMarkersCount =
1366 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1367 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1369 "mismatch between dynamic and static positions (kDynamic marker but no "
1370 "corresponding dynamic position) -- this can only happen due to an "
1371 "incorrect fold/rewrite");
1372 auto position = getMixedPosition();
1373 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1375 "expected position attribute of rank no greater than vector rank");
1376 for (
auto [idx, pos] : llvm::enumerate(position)) {
1377 if (
auto attr = dyn_cast<Attribute>(pos)) {
1378 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
// kPoisonIndex (-1) is explicitly allowed as an out-of-bounds marker.
1380 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1381 return emitOpError(
"expected position attribute #")
1383 <<
" to be a non-negative integer smaller than the "
1384 "corresponding vector dimension or poison (-1)";
// extractVector<IntType>: converts an ArrayAttr of IntegerAttrs into a
// SmallVector of the requested integer type.
1391template <
typename IntType>
1393 return llvm::map_to_vector<4>(
1394 arrayAttr.getAsRange<IntegerAttr>(),
1395 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
// foldExtractOpFromExtractChain (fragment): collapses extract(extract(...))
// chains into a single extract by concatenating the (all-static) positions.
// Positions are accumulated in reverse and flipped once at the end.
1401 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1405 if (extractOp.hasDynamicPosition())
1409 ExtractOp currentOp = extractOp;
1411 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1412 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1415 if (currentOp.hasDynamicPosition())
1418 globalPosition.append(extrPos.rbegin(), extrPos.rend());
// Rewire the outer extract to read straight from the chain's root source.
1420 extractOp.setOperand(0, currentOp.getSource());
1423 std::reverse(globalPosition.begin(), globalPosition.end());
1424 extractOp.setStaticPosition(globalPosition);
// Helper class (fragment): walks a chain of vector.insert / vector.transpose
// ops feeding a vector.extract and tries to fold the extract to the inserted
// value or to an extract from an earlier source. `extractPosition` tracks
// the current (possibly transposed) position; trailing entries are negative
// `sentinels` so positions can be compared at full vector rank.
1436class ExtractFromInsertTransposeChainState {
1438 ExtractFromInsertTransposeChainState(ExtractOp e);
// True when `a` is a prefix of `b` (element-wise equal over a's length).
1447 template <
typename ContainerA,
typename ContainerB>
1448 bool isContainedWithin(
const ContainerA &a,
const ContainerB &
b) {
1449 return a.size() <=
b.size() &&
1450 std::equal(a.begin(), a.begin() + a.size(),
b.begin());
// Compares a and b element-wise, but skips pairs where either side is a
// negative sentinel (remaining comparison logic is on elided lines).
1457 template <
typename ContainerA,
typename ContainerB>
1458 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &
b) {
1459 for (
auto [elemA, elemB] : llvm::zip(a,
b)) {
1460 if (elemA < 0 || elemB < 0)
// Folding in place is only legal while the sentinel tail is untouched.
1471 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1475 void updateStateForNextIteration(Value v) {
1482 LogicalResult handleTransposeOp();
1485 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1500 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1505 Value tryToFoldExtractOpInPlace(Value source);
1507 ExtractOp extractOp;
1509 int64_t extractedRank;
1511 InsertOp nextInsertOp;
1512 TransposeOp nextTransposeOp;
1522 SmallVector<int64_t> sentinels;
1523 SmallVector<int64_t> extractPosition;
// Constructor: seeds extractPosition with the op's static position followed
// by one unique negative sentinel per non-extracted trailing dimension.
1527ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1529 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1530 extractedRank(extractOp.getNumIndices()) {
1531 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1532 sentinels.reserve(vectorRank - extractedRank);
1533 for (
int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1534 sentinels.push_back(-(i + 1));
1536 extractOp.getStaticPosition().end());
// Chain-walk handlers (fragments). All bail out on dynamic positions.
// handleTransposeOp: folds a transpose into the state by permuting the
// tracked extract position with the transpose's permutation.
1542LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1544 if (extractOp.hasDynamicPosition())
1547 if (!nextTransposeOp)
1550 nextTransposeOp.getPermutation(), extractOp.getContext()));
// handleInsertOpWithMatchingPos: when the insert position equals the
// extract position exactly, the extract yields the inserted value.
1557ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1560 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1563 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1564 if (insertedPos != llvm::ArrayRef(
extractPosition).take_front(extractedRank))
1567 res = nextInsertOp.getValueToStore();
// handleInsertOpWithPrefixPos: the insert position is a strict prefix of
// the extract position, so the extract continues into the inserted value.
1576ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1578 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1581 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1591 res = nextInsertOp.getValueToStore();
// tryToFoldExtractOpInPlace: rewrites the extract's source/position in
// place when the accumulated state still permits it (canFold()).
1599Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1602 if (extractOp.hasDynamicPosition())
1606 bool nothingToFold = (source == extractOp.getSource());
1607 if (nothingToFold || !canFold())
1611 OpBuilder
b(extractOp.getContext());
1612 extractOp.setStaticPosition(
1614 extractOp.getSourceMutable().assign(source);
1615 return extractOp.getResult();
// fold: iterates the transpose/insert chain, applying the handlers above;
// disjoint inserts are skipped by descending into the insert's destination.
1619Value ExtractFromInsertTransposeChainState::fold() {
1621 if (extractOp.hasDynamicPosition())
1624 Value valueToExtractFrom = extractOp.getSource();
1625 updateStateForNextIteration(valueToExtractFrom);
1626 while (nextInsertOp || nextTransposeOp) {
1629 if (succeeded(handleTransposeOp())) {
1630 valueToExtractFrom = nextTransposeOp.getVector();
1631 updateStateForNextIteration(valueToExtractFrom);
1637 if (succeeded(handleInsertOpWithMatchingPos(
result)))
1642 if (succeeded(handleInsertOpWithPrefixPos(
result)))
1643 return tryToFoldExtractOpInPlace(
result);
1647 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
// Insert writes elsewhere: look through it into the destination vector.
1653 valueToExtractFrom = nextInsertOp.getDest();
1654 updateStateForNextIteration(valueToExtractFrom);
1657 return tryToFoldExtractOpInPlace(valueToExtractFrom);
// Broadcast-related folds (fragments). First: a predicate for 0-D vector
// types and a broadcast-likeness check — vector.broadcast itself, or a
// shape_cast that only prepends dimensions (trailing dims unchanged).
1662 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1663 auto vecType = dyn_cast<VectorType>(type);
1664 return vecType && vecType.getRank() == 0;
1674 if (isa<BroadcastOp>(op))
1677 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1685 VectorType srcType = shapeCast.getSourceVectorType();
1687 uint64_t srcRank = srcType.getRank();
// shape_cast acts like a broadcast iff it only adds leading dims.
1689 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
// foldExtractFromBroadcast (fragment): rewrites extract(broadcast(x)) to
// read directly from x, remapping positions across the rank difference and
// substituting 0 for positions that index broadcast (size-1) dims.
1715 Operation *defOp = extractOp.getSource().getDefiningOp();
// Extracting the whole broadcast input is just the input itself.
1722 if (extractOp.getType() == input.
getType())
1728 auto inputType = llvm::dyn_cast<VectorType>(input.
getType());
1729 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1730 unsigned inputRank = inputType ? inputType.getRank() : 0;
1731 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1732 unsigned extractRank = extractType ? extractType.getRank() : 0;
1735 if (extractRank > inputRank)
1739 assert(inputType &&
"input must be a vector type because of previous checks");
1748 extractType.getShape() != inputShape.take_back(extractRank))
// deltaOverall: dims the new extract must index; deltaBroadcast: leading
// dims that exist only in the broadcast result.
1753 unsigned deltaOverall = inputRank - extractRank;
1754 unsigned deltaBroadcast = broadcastRank - inputRank;
1758 for (
auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1759 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1762 extractOp->setOperands(
1763 llvm::to_vector(llvm::concat<Value>(
ValueRange(input), dynPos)));
1764 extractOp.setStaticPosition(staticPos);
1765 return extractOp.getResult();
1781 if (extractOp.hasDynamicPosition())
1784 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1789 if (shuffleOp.getResultVectorType().getRank() != 1)
1792 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1793 auto shuffleMask = shuffleOp.getMask();
1794 int64_t extractIdx = extractOp.getStaticPosition()[0];
1795 int64_t shuffleIdx = shuffleMask[extractIdx];
1798 if (shuffleIdx < inputVecSize) {
1799 extractOp.setOperand(0, shuffleOp.getV1());
1800 extractOp.setStaticPosition({shuffleIdx});
1802 extractOp.setOperand(0, shuffleOp.getV2());
1803 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1806 return extractOp.getResult();
1812 if (extractOp.hasDynamicPosition())
1815 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1820 auto getDimReverse = [](VectorType type,
int64_t n) {
1821 return type.getShape().take_back(n + 1).front();
1824 llvm::isa<VectorType>(extractOp.getType())
1825 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1827 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1829 if (destinationRank > 0) {
1830 auto destinationType =
1831 llvm::cast<VectorType>(extractOp.getResult().getType());
1832 for (
int64_t i = 0; i < destinationRank; i++) {
1836 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1837 getDimReverse(destinationType, i))
1844 std::reverse(extractedPos.begin(), extractedPos.end());
1847 for (
int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1848 strides.push_back(stride);
1850 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1858 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1860 for (
int64_t i = 0; i < numDimension; i++) {
1861 newStrides.push_back(stride);
1863 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1865 std::reverse(newStrides.begin(), newStrides.end());
1869 extractOp.setStaticPosition(newPosition);
1870 extractOp.setOperand(0, shapeCastOp.getSource());
1871 return extractOp.getResult();
1877 if (extractOp.hasDynamicPosition())
1880 auto extractStridedSliceOp =
1881 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1882 if (!extractStridedSliceOp)
1891 if (extractStridedSliceOp.hasNonUnitStrides())
1897 while (!sliceOffsets.empty()) {
1898 size_t lastOffset = sliceOffsets.size() - 1;
1899 if (sliceOffsets.back() != 0 ||
1900 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1901 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1903 sliceOffsets.pop_back();
1905 unsigned destinationRank = 0;
1906 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1907 destinationRank = vecType.getRank();
1910 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1911 sliceOffsets.size())
1915 assert(extractedPos.size() >= sliceOffsets.size());
1916 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1917 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1918 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1922 extractOp.setStaticPosition(extractedPos);
1923 return extractOp.getResult();
1929 if (extractOp.hasDynamicPosition())
1933 llvm::isa<VectorType>(extractOp.getType())
1934 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1936 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1946 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1947 insertOp.getSourceVectorType().getRank();
1948 if (destinationRank > insertOp.getSourceVectorType().getRank())
1953 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1954 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1957 bool disjoint =
false;
1959 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1960 int64_t start = insertOffsets[dim];
1962 (dim < insertRankDiff)
1964 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1966 int64_t offset = extractOffsets[dim];
1968 if (start <= offset && offset < end) {
1969 if (dim >= insertRankDiff)
1970 offsetDiffs.push_back(offset - start);
1981 insertOp.getSourceVectorType().getRank() - destinationRank;
1982 for (
int64_t i = 0; i < destinationRank; i++) {
1983 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1984 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1988 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
1991 extractOp.setStaticPosition(offsetDiffs);
1992 return extractOp.getResult();
1996 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2009 if (extractOp.hasDynamicPosition())
2013 auto fromElementsOp = extractOp.getSource().
getDefiningOp<FromElementsOp>();
2014 if (!fromElementsOp)
2018 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2019 if (vecType.isScalable())
2023 int64_t rank = vecType.getRank();
2025 if (extractOp.getType() != vecType.getElementType())
2028 "unexpected number of indices");
2033 for (
int i = rank - 1; i >= 0; --i) {
2034 flatIndex +=
indices[i] * stride;
2035 stride *= vecType.getDimSize(i);
2037 return fromElementsOp.getElements()[flatIndex];
2042template <
typename OpType,
typename AdaptorType>
2045 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2046 OperandRange dynamicPosition = op.getDynamicPosition();
2049 if constexpr (std::is_same_v<OpType, ExtractOp>)
2050 vectorShape = op.getSourceVectorType().getShape();
2055 if (!dynamicPosition.size())
2062 bool opChange =
false;
2063 for (
unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2064 if (ShapedType::isStatic(staticPosition[i]))
2068 if (
auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2069 int64_t value = attr.getInt();
2073 staticPosition[i] = attr.getInt();
2078 operands.push_back(position);
2082 op.setStaticPosition(staticPosition);
2083 op.getOperation()->setOperands(operands);
2085 return op.getResult();
2095 if (!is_contained(staticPos, poisonVal))
2098 return ub::PoisonAttr::get(context);
2112 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2117 if (denseAttr.isSplat()) {
2119 if (
auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2124 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2125 if (vecTy.isScalable())
2128 if (extractOp.hasDynamicPosition()) {
2143 copy(extractOp.getStaticPosition(), completePositions.begin());
2146 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2149 if (
auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2151 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2154 newAttr = *denseValuesBegin;
2160OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2164 if (getNumIndices() == 0 && getSource().
getType() == getResult().
getType())
2171 SmallVector<Value> operands = {getSource()};
2175 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2181 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2196 return inplaceFolded;
2202class ExtractOpFromBroadcast final :
public OpRewritePattern<ExtractOp> {
2206 LogicalResult matchAndRewrite(ExtractOp extractOp,
2207 PatternRewriter &rewriter)
const override {
2210 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2216 BroadcastableToResult::Success)
2225class ExtractOpFromCreateMask final :
public OpRewritePattern<ExtractOp> {
2229 LogicalResult matchAndRewrite(ExtractOp extractOp,
2230 PatternRewriter &rewriter)
const override {
2232 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2236 VectorType extractedMaskType =
2237 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2239 if (!extractedMaskType)
2242 auto maskOperands = createMaskOp.getOperands();
2243 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2244 VectorType maskType = createMaskOp.getVectorType();
2246 bool containsUnknownDims =
false;
2249 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2251 int64_t pos = extractOpPos[dimIdx];
2252 Value operand = maskOperands[dimIdx];
2253 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2256 containsUnknownDims =
true;
2260 int64_t createMaskBound =
2261 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2263 if (pos != ShapedType::kDynamic) {
2266 allFalse |= pos >= createMaskBound;
2267 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2271 containsUnknownDims =
true;
2278 }
else if (!containsUnknownDims) {
2280 extractOp, extractedMaskType,
2281 maskOperands.drop_front(extractOpPos.size()));
2290class ExtractOpFromConstantMask final :
public OpRewritePattern<ExtractOp> {
2294 LogicalResult matchAndRewrite(ExtractOp extractOp,
2295 PatternRewriter &rewriter)
const override {
2296 auto constantMaskOp =
2297 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2298 if (!constantMaskOp)
2301 Type resultType = extractOp.getResult().getType();
2302 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2304 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2305 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2307 VectorType maskType = constantMaskOp.getVectorType();
2310 for (
size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2311 int64_t pos = extractOpPos[dimIdx];
2312 if (pos == ShapedType::kDynamic) {
2315 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2324 if (pos >= maskDimSizes[dimIdx]) {
2325 if (extractedMaskType) {
2337 if (extractedMaskType) {
2341 extractOp, extractedMaskType,
2342 maskDimSizes.drop_front(extractOpPos.size()));
2355LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2356 PatternRewriter &rewriter) {
2357 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2361 VectorType sourceType = castOp.getSourceVectorType();
2362 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2366 if (sourceType.getNumElements() != targetType.getNumElements())
2370 castOp.getSource());
2380LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2381 PatternRewriter &rewriter) {
2383 if (extractOp.hasDynamicPosition())
2387 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2392 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2393 if (!fromElementsOp)
2395 VectorType inputType = fromElementsOp.getType();
2398 if (resultType.isScalable() || inputType.isScalable())
2403 SmallVector<int64_t> firstElementPos =
2404 llvm::to_vector(extractOp.getStaticPosition());
2405 firstElementPos.append(resultType.getRank(), 0);
2408 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2409 flatIndex += firstElementPos[i] * stride;
2410 stride *= inputType.getDimSize(i);
2415 extractOp, resultType,
2416 fromElementsOp.getElements().slice(flatIndex,
2417 resultType.getNumElements()));
2429struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2431 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2432 PatternRewriter &rewriter)
const override {
2433 VectorType sourceType = extractOp.getSourceVectorType();
2434 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2438 if (sourceType.getNumElements() != outType.getNumElements())
2440 extractOp,
"extract to vector with fewer elements");
2444 if (llvm::any_of(extractOp.getMixedPosition(),
2445 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2447 "leaving for extract poison folder");
2450 extractOp.getSource());
2458void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2459 MLIRContext *context) {
2460 results.
add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2461 ExtractOpFromConstantMask, ExtractToShapeCast>(context);
2462 results.
add(foldExtractFromShapeCastToShapeCast);
2463 results.
add(foldExtractFromFromElements);
2468 for (
auto attr : arrayAttr)
2469 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2476std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2487 if (operands.empty())
2490 return llvm::all_of(operands, [&](
Value operand) {
2492 return currentDef == defOp;
2510 auto fromElementsOp =
2511 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2512 if (!fromElementsOp)
2515 llvm::append_range(results, fromElementsOp.getElements());
2532 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2536 if (isa<VectorType>(bcastOp.getSource().getType()))
2539 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2541 Value scalar = bcastOp.getSource();
2542 results.assign(resultVecType.getNumElements(), scalar);
2546LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2547 SmallVectorImpl<OpFoldResult> &results) {
2552 if (
auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2553 setOperand(shapeCast.getSource());
2561ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2562 ToElementsOp::Adaptor adaptor,
2563 SmallVectorImpl<Type> &inferredReturnTypes) {
2564 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2565 Type elType = vecType.getElementType();
2566 inferredReturnTypes.append(vecType.getNumElements(), elType);
2588 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2593 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2597 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2602 int64_t dstRank = dstShape.size();
2603 int64_t srcRank = srcShape.size();
2606 auto srcElems = vector::ToElementsOp::create(
2607 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2609 int64_t dstCount = llvm::product_of(dstShape);
2612 replacements.reserve(dstCount);
2637 for (
int64_t lin = 0; lin < dstCount; ++lin) {
2640 for (
int64_t k = 0; k < srcRank; ++k)
2641 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2644 replacements.push_back(srcElems.getResult(srcLin));
2647 rewriter.
replaceOp(toElementsOp, replacements);
2652void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2653 MLIRContext *context) {
2654 results.
add<ToElementsOfBroadcast>(context);
2674 OperandRange fromElemsOperands = fromElementsOp.getElements();
2675 if (fromElemsOperands.empty())
2678 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2686 Value toElementsInput = toElementsOp.getSource();
2687 if (fromElementsOp.getType() == toElementsInput.
getType() &&
2688 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2689 return toElementsInput;
2709 if (llvm::any_of(elements, [](
Attribute attr) {
2715 auto destVecType = fromElementsOp.getDest().getType();
2716 auto destEltType = destVecType.getElementType();
2717 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2722 auto convertedElements = llvm::map_to_vector(elements, [&](
Attribute attr) {
2729OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2746 if (!llvm::all_equal(fromElementsOp.getElements()))
2749 fromElementsOp, fromElementsOp.getType(),
2750 fromElementsOp.getElements().front());
2778 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2782 if (fromElements.getType().getNumElements() == 1)
2793 for (
auto [insertIndex, element] :
2794 llvm::enumerate(fromElements.getElements())) {
2797 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2800 "element not from vector.extract");
2805 if (insertIndex == 0) {
2806 source = extractOp.getSource();
2807 }
else if (extractOp.getSource() != source) {
2809 "element from different vector");
2813 int64_t rank = position.size();
2814 assert(rank == source.getType().getRank() &&
2815 "scalar extract must have full rank position");
2826 if (insertIndex == 0) {
2827 const int64_t numElms = fromElements.getType().getNumElements();
2830 while (
index > 0 && position[
index - 1] == 0 &&
2831 numSuffixElms < numElms) {
2832 numSuffixElms *= source.getType().getDimSize(
index - 1);
2835 if (numSuffixElms != numElms) {
2837 fromElements,
"elements do not form a suffix of source");
2839 expectedPosition = llvm::to_vector(position);
2840 combinedPosition = position.drop_back(rank -
index);
2844 else if (expectedPosition != position) {
2846 fromElements,
"elements not in ascending order (static order)");
2848 increment(expectedPosition, source.getType().getShape());
2851 auto extracted = rewriter.
createOrFold<vector::ExtractOp>(
2852 fromElements.getLoc(), source, combinedPosition);
2855 fromElements, fromElements.getType(), extracted);
2863 for (
int dim : llvm::reverse(llvm::seq<int>(0,
indices.size()))) {
2882void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2884 setResultRanges(getResult(), argRanges.front());
2887std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2888 return llvm::to_vector<4>(getResultVectorType().
getShape());
2893static llvm::SetVector<int64_t>
2896 int64_t rankDiff = dstShape.size() - srcShape.size();
2899 for (
auto [s1, s2] :
2900 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2902 assert(s1 == 1 &&
"expected \"dim-1\" broadcasting");
2910llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2912 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2915 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2931Value BroadcastOp::createOrFoldBroadcastOp(
2932 OpBuilder &
b, Value value, ArrayRef<int64_t> dstShape,
2933 const llvm::SetVector<int64_t> &broadcastedDims) {
2934 assert(!dstShape.empty() &&
"unexpected empty dst shape");
2937 SmallVector<int64_t> checkShape;
2938 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
2939 if (broadcastedDims.contains(i))
2941 checkShape.push_back(dstShape[i]);
2943 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2944 "ill-formed broadcastedDims contains values not confined to "
2947 Location loc = value.
getLoc();
2949 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
2950 VectorType dstVectorType = VectorType::get(dstShape, elementType);
2953 if (!srcVectorType) {
2954 assert(checkShape.empty() &&
2955 "ill-formed createOrFoldBroadcastOp arguments");
2956 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2959 assert(srcVectorType.getShape().equals(checkShape) &&
2960 "ill-formed createOrFoldBroadcastOp arguments");
2970 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2971 broadcastShape.reserve(dstShape.size());
2987 int64_t nextSrcShapeDim = broadcastedDims.size();
2988 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2989 if (broadcastedDims.contains(i)) {
2994 broadcastShape.push_back(dstShape[i]);
2995 permutation[i] = broadcastShape.size() - 1;
3001 permutation[i] = nextSrcShapeDim++;
3005 llvm::append_range(broadcastShape, srcVectorType.getShape());
3010 "unexpected \"dim-1\" broadcast");
3012 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3014 vector::BroadcastableToResult::Success &&
3015 "must be broadcastable");
3016 Value res =
b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3019 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3020 if (permutation[i] != i)
3021 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3027 Type srcType, VectorType dstVectorType,
3028 std::pair<VectorDim, VectorDim> *mismatchingDims) {
3030 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3034 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3038 int64_t srcRank = srcVectorType.getRank();
3039 int64_t dstRank = dstVectorType.getRank();
3040 if (srcRank > dstRank)
3044 int64_t lead = dstRank - srcRank;
3045 for (
int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3048 bool foundMismatchingDims =
false;
3051 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3052 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3053 if (srcDim != 1 && srcDim != dstDim)
3054 foundMismatchingDims =
true;
3057 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3058 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3059 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3062 (srcDimScalableFlag != dstDimScalableFlag &&
3063 (srcDim != 1 || srcDimScalableFlag)))
3064 foundMismatchingDims =
true;
3066 if (foundMismatchingDims) {
3067 if (mismatchingDims !=
nullptr) {
3068 mismatchingDims->first.dim = srcDim;
3069 mismatchingDims->first.isScalable = srcDimScalableFlag;
3071 mismatchingDims->second.dim = dstDim;
3072 mismatchingDims->second.isScalable = dstDimScalableFlag;
3081LogicalResult BroadcastOp::verify() {
3082 std::pair<VectorDim, VectorDim> mismatchingDims;
3084 getSourceType(), getResultVectorType(), &mismatchingDims);
3088 return emitOpError(
"source rank higher than destination rank");
3091 << (mismatchingDims.first.isScalable ?
"[" :
"")
3092 << mismatchingDims.first.dim
3093 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
3094 << (mismatchingDims.second.isScalable ?
"[" :
"")
3095 << mismatchingDims.second.dim
3096 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
3099 return emitOpError(
"source type is not a vector");
3100 llvm_unreachable(
"unexpected vector.broadcast op error");
3107 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3111 VectorType srcType = srcShapeCast.getSourceVectorType();
3112 VectorType destType = broadcastOp.getResultVectorType();
3120 srcShapeCast.getResultVectorType().getShape();
3123 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3124 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3125 shapecastShape.take_back(numTrailingDims)))
3128 assert(all_of(srcShape.drop_back(numTrailingDims),
3129 [](
int64_t E) { return E == 1; }) &&
3130 all_of(shapecastShape.drop_back(numTrailingDims),
3131 [](
int64_t E) { return E == 1; }) &&
3132 "ill-formed shape_cast");
3134 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3138OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3139 if (getSourceType() == getResultVectorType())
3144 if (!adaptor.getSource())
3146 auto vectorType = getResultVectorType();
3147 if (
auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3148 if (vectorType.getElementType() != attr.getType())
3152 if (
auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3153 if (vectorType.getElementType() != attr.getType())
3157 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3167struct BroadcastFolder :
public OpRewritePattern<BroadcastOp> {
3170 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3171 PatternRewriter &rewriter)
const override {
3172 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3176 broadcastOp.getResultVectorType(),
3177 srcBroadcast.getSource());
3190struct BroadcastToShapeCast final
3191 :
public OpRewritePattern<vector::BroadcastOp> {
3193 LogicalResult matchAndRewrite(vector::BroadcastOp
broadcast,
3194 PatternRewriter &rewriter)
const override {
3196 auto sourceType = dyn_cast<VectorType>(
broadcast.getSourceType());
3199 broadcast,
"source is a scalar, shape_cast doesn't support scalar");
3203 if (sourceType.getNumElements() != outType.getNumElements()) {
3205 broadcast,
"broadcast to a greater number of elements");
3215void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3216 MLIRContext *context) {
3217 results.
add<BroadcastFolder, BroadcastToShapeCast>(context);
3224LogicalResult ShuffleOp::verify() {
3225 VectorType resultType = getResultVectorType();
3226 VectorType v1Type = getV1VectorType();
3227 VectorType v2Type = getV2VectorType();
3229 int64_t resRank = resultType.getRank();
3230 int64_t v1Rank = v1Type.getRank();
3231 int64_t v2Rank = v2Type.getRank();
3232 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3233 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3234 if (!wellFormed0DCase && !wellFormedNDCase)
3238 for (int64_t r = 1; r < v1Rank; ++r) {
3239 int64_t resDim = resultType.getDimSize(r);
3240 int64_t v1Dim = v1Type.getDimSize(r);
3241 int64_t v2Dim = v2Type.getDimSize(r);
3242 if (resDim != v1Dim || v1Dim != v2Dim)
3246 ArrayRef<int64_t> mask = getMask();
3247 int64_t maskLength = mask.size();
3248 if (maskLength <= 0)
3250 if (maskLength != resultType.getDimSize(0))
3253 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3254 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3255 for (
auto [idx, maskPos] : llvm::enumerate(mask)) {
3257 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
3263ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3264 ShuffleOp::Adaptor adaptor,
3265 SmallVectorImpl<Type> &inferredReturnTypes) {
3266 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3270 auto v1Rank = v1Type.getRank();
3273 SmallVector<int64_t, 4> shape;
3274 shape.reserve(v1Rank);
3275 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3278 llvm::append_range(shape, v1Type.getShape().drop_front());
3279 inferredReturnTypes.push_back(
3280 VectorType::get(shape, v1Type.getElementType()));
3284template <
typename T>
3287 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3288 return value == expected++;
3292OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3293 auto v1Type = getV1VectorType();
3294 auto v2Type = getV2VectorType();
3296 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3297 "Vector shuffle does not support scalable vectors");
3301 if (v1Type.getRank() == 0)
3305 auto mask = getMask();
3312 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3313 if (!v1Attr || !v2Attr)
3319 if (isV1Poison && isV2Poison)
3324 if (v1Type.getRank() != 1)
3330 SmallVector<Attribute> v1Elements, v2Elements;
3331 Attribute poisonElement;
3333 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3336 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3337 poisonElement = v2Elements[0];
3340 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3343 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3344 poisonElement = v1Elements[0];
3347 SmallVector<Attribute> results;
3348 int64_t v1Size = v1Type.getDimSize(0);
3349 for (int64_t maskIdx : mask) {
3350 Attribute indexedElm;
3352 if (maskIdx == ShuffleOp::kPoisonIndex) {
3353 indexedElm = poisonElement;
3355 if (maskIdx < v1Size)
3356 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3358 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3361 results.push_back(indexedElm);
3371struct Canonicalize0DShuffleOp :
public OpRewritePattern<ShuffleOp> {
3374 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3375 PatternRewriter &rewriter)
const override {
3376 VectorType v1VectorType = shuffleOp.getV1VectorType();
3377 ArrayRef<int64_t> mask = shuffleOp.getMask();
3378 if (v1VectorType.getRank() > 0)
3380 if (mask.size() != 1)
3382 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3400static Value getScalarSplatSource(Value value) {
3406 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3413 if (isa<VectorType>(
broadcast.getSourceType()))
3421class ShuffleSplat final :
public OpRewritePattern<ShuffleOp> {
3425 LogicalResult matchAndRewrite(ShuffleOp op,
3426 PatternRewriter &rewriter)
const override {
3427 Value splat = getScalarSplatSource(op.getV1());
3428 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3438class ShuffleInterleave :
public OpRewritePattern<ShuffleOp> {
3442 LogicalResult matchAndRewrite(ShuffleOp op,
3443 PatternRewriter &rewriter)
const override {
3444 VectorType resultType = op.getResultVectorType();
3445 if (resultType.isScalable())
3447 op,
"ShuffleOp can't represent a scalable interleave");
3449 if (resultType.getRank() != 1)
3451 op,
"ShuffleOp can't represent an n-D interleave");
3453 VectorType sourceType = op.getV1VectorType();
3454 if (sourceType != op.getV2VectorType() ||
3455 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3457 op,
"ShuffleOp types don't match an interleave");
3460 ArrayRef<int64_t> shuffleMask = op.getMask();
3461 int64_t resultVectorSize = resultType.getNumElements();
3462 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3463 int64_t maskValueA = shuffleMask[i * 2];
3464 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3465 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3467 "ShuffleOp mask not interleaving");
3477void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3478 MLIRContext *context) {
3479 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3487void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3489 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3492void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3493 Value source, Value dest) {
3494 auto vectorTy = cast<VectorType>(dest.
getType());
3495 build(builder,
result, source, dest,
3496 SmallVector<int64_t>(vectorTy.getRank(), 0));
3499void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3500 Value source, Value dest, int64_t position) {
3501 build(builder,
result, source, dest, ArrayRef<int64_t>{position});
3504void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3505 Value source, Value dest, OpFoldResult position) {
3506 build(builder,
result, source, dest, ArrayRef<OpFoldResult>{position});
3509void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3510 Value source, Value dest,
3511 ArrayRef<int64_t> position) {
3512 SmallVector<OpFoldResult> posVals;
3513 posVals.reserve(position.size());
3514 llvm::transform(position, std::back_inserter(posVals),
3516 build(builder,
result, source, dest, posVals);
3519void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3520 Value source, Value dest,
3521 ArrayRef<OpFoldResult> position) {
3522 SmallVector<int64_t> staticPos;
3523 SmallVector<Value> dynamicPos;
3525 build(builder,
result, source, dest, dynamicPos,
3529LogicalResult InsertOp::verify() {
3530 if (
auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3531 if (srcTy.getRank() == 0)
3533 "expected a scalar instead of a 0-d vector as the source operand");
3535 SmallVector<OpFoldResult> position = getMixedPosition();
3536 auto destVectorType = getDestVectorType();
3537 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
3539 "expected position attribute of rank no greater than dest vector rank");
3540 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3541 if (srcVectorType &&
3542 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3543 static_cast<unsigned>(destVectorType.getRank())))
3544 return emitOpError(
"expected position attribute rank + source rank to "
3545 "match dest vector rank");
3546 if (!srcVectorType &&
3547 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
3549 "expected position attribute rank to match the dest vector rank");
3550 for (
auto [idx, pos] : llvm::enumerate(position)) {
3551 if (
auto attr = dyn_cast<Attribute>(pos)) {
3552 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3554 destVectorType.getDimSize(idx))) {
3555 return emitOpError(
"expected position attribute #")
3557 <<
" to be a non-negative integer smaller than the "
3559 "dest vector dimension";
3572 assert(positions.size() <= completePositions.size() &&
3573 "positions size must be less than or equal to destTy rank");
3574 copy(positions, completePositions.begin());
3582class InsertToBroadcast final :
public OpRewritePattern<InsertOp> {
3586 LogicalResult matchAndRewrite(InsertOp insertOp,
3587 PatternRewriter &rewriter)
const override {
3589 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3590 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3591 srcVecType.getNumElements())
3594 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3600class InsertSplatToSplat final :
public OpRewritePattern<InsertOp> {
3604 LogicalResult matchAndRewrite(InsertOp op,
3605 PatternRewriter &rewriter)
const override {
3607 Value splat = getScalarSplatSource(op.getValueToStore());
3608 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3636class InsertChainFullyInitialized final :
public OpRewritePattern<InsertOp> {
3639 LogicalResult matchAndRewrite(InsertOp op,
3640 PatternRewriter &rewriter)
const override {
3642 VectorType destTy = op.getDestVectorType();
3643 if (destTy.isScalable())
3646 for (Operation *user : op.getResult().getUsers())
3647 if (
auto insertOp = dyn_cast<InsertOp>(user))
3648 if (insertOp.getDest() == op.getResult())
3651 InsertOp currentOp = op;
3652 SmallVector<InsertOp> chainInsertOps;
3655 if (currentOp.hasDynamicPosition())
3658 chainInsertOps.push_back(currentOp);
3659 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3662 if (currentOp && !currentOp->hasOneUse())
3666 int64_t vectorSize = destTy.getNumElements();
3667 int64_t initializedCount = 0;
3668 SmallVector<bool> initializedDestIdxs(vectorSize,
false);
3669 SmallVector<int64_t> pendingInsertPos;
3670 SmallVector<int64_t> pendingInsertSize;
3671 SmallVector<Value> pendingInsertValues;
3673 for (
auto insertOp : chainInsertOps) {
3675 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3679 int64_t insertBeginPosition =
3684 int64_t insertSize = 1;
3685 if (
auto srcVectorType =
3686 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3687 insertSize = srcVectorType.getNumElements();
3689 assert(insertBeginPosition + insertSize <= vectorSize &&
3690 "insert would overflow the vector");
3692 for (
auto index : llvm::seq<int64_t>(insertBeginPosition,
3693 insertBeginPosition + insertSize)) {
3694 if (initializedDestIdxs[index])
3696 initializedDestIdxs[index] =
true;
3702 pendingInsertPos.push_back(insertBeginPosition);
3703 pendingInsertSize.push_back(insertSize);
3704 pendingInsertValues.push_back(insertOp.getValueToStore());
3706 if (initializedCount == vectorSize)
3711 if (initializedCount != vectorSize)
3714 SmallVector<Value> elements(vectorSize);
3715 for (
auto [insertBeginPosition, insertSize, valueToStore] :
3716 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3717 pendingInsertValues))) {
3718 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3720 if (!srcVectorType) {
3721 elements[insertBeginPosition] = valueToStore;
3725 Repeated<Type> elementToInsertTypes(insertSize,
3726 srcVectorType.getElementType());
3728 auto elementsToInsert = vector::ToElementsOp::create(
3729 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3730 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3731 elements[insertBeginPosition + linearIdx] =
3732 elementsToInsert.getResult(linearIdx);
3746 int64_t maxVectorSizeFoldThreshold) {
3747 if (insertOp.hasDynamicPosition())
3750 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3758 VectorType destTy = insertOp.getDestVectorType();
3759 if (destTy.isScalable())
3763 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3764 !insertOp->hasOneUse())
3769 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3776 Type destEltType = destTy.getElementType();
3780 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3781 for (
auto value : denseSource.getValues<
Attribute>())
3787 auto allValues = llvm::to_vector(denseDst.getValues<
Attribute>());
3788 copy(insertedValues, allValues.begin() + insertBeginPosition);
3797 auto destInsert = insertOp.getDest().
getDefiningOp<InsertOp>();
3801 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3804 insertOp.
setOperand(1, destInsert.getDest());
3805 return insertOp.getResult();
3808void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3809 MLIRContext *context) {
3810 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3811 InsertChainFullyInitialized>(context);
3814OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3817 constexpr int64_t vectorSizeFoldThreshold = 256;
3821 if (getNumIndices() == 0 && getValueToStoreType() ==
getType())
3822 return getValueToStore();
3826 SmallVector<Value> operands = {getValueToStore(), getDest()};
3832 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3835 *
this, adaptor.getValueToStore(), adaptor.getDest(),
3836 vectorSizeFoldThreshold)) {
3840 return inplaceFolded;
3847void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
3848 Value source, Value dest,
3849 ArrayRef<int64_t> offsets,
3850 ArrayRef<int64_t> strides) {
3851 result.addOperands({source, dest});
3855 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(
result.name),
3857 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(
result.name),
3862template <
typename OpType>
3866 StringRef attrName) {
3867 if (arrayAttr.size() >
shape.size())
3868 return op.emitOpError(
"expected ")
3869 << attrName <<
" attribute of rank no greater than vector rank";
3876template <
typename OpType>
3880 bool halfOpen =
true) {
3881 for (
auto attr : arrayAttr) {
3882 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3886 if (val < min || val >= upper)
3887 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
3888 <<
min <<
", " << upper <<
")";
3896template <
typename OpType>
3901 for (
auto [
index, attrDimPair] :
3902 llvm::enumerate(llvm::zip_first(arrayAttr,
shape))) {
3903 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3907 if (val < min || val >=
max)
3908 return op.emitOpError(
"expected ")
3909 << attrName <<
" dimension " <<
index <<
" to be confined to ["
3910 <<
min <<
", " <<
max <<
")";
3920template <
typename OpType>
3925 assert(arrayAttr1.size() <=
shape.size());
3926 assert(arrayAttr2.size() <=
shape.size());
3927 for (
auto [
index, it] :
3928 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2,
shape))) {
3929 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3930 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3934 if (val1 + val2 < 0 || val1 + val2 >=
max)
3935 return op.emitOpError(
"expected sum(")
3936 << attrName1 <<
", " << attrName2 <<
") dimension " <<
index
3937 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3945 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3947 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3950LogicalResult InsertStridedSliceOp::verify() {
3951 auto sourceVectorType = getSourceVectorType();
3952 auto destVectorType = getDestVectorType();
3953 auto offsets = getOffsetsAttr();
3954 auto strides = getStridesAttr();
3955 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3957 "expected offsets of same size as destination vector rank");
3958 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3959 return emitOpError(
"expected strides of same size as source vector rank");
3960 if (sourceVectorType.getRank() > destVectorType.getRank())
3962 "expected source rank to be no greater than destination rank");
3964 auto sourceShape = sourceVectorType.getShape();
3965 auto destShape = destVectorType.getShape();
3966 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3967 destShape.size() - sourceShape.size(), 0);
3968 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3969 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3970 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3979 offName,
"source vector shape",
3983 unsigned rankDiff = destShape.size() - sourceShape.size();
3984 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3985 if (sourceVectorType.getScalableDims()[idx] !=
3986 destVectorType.getScalableDims()[idx + rankDiff]) {
3987 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3990 if (sourceVectorType.getScalableDims()[idx]) {
3991 auto sourceSize = sourceShape[idx];
3992 auto destSize = destShape[idx + rankDiff];
3993 if (sourceSize != destSize) {
3996 << (
" to match the corresponding base size from the input "
3998 << sourceSize << (
" vs ") << destSize << (
")");
4008class FoldInsertStridedSliceSplat final
4009 :
public OpRewritePattern<InsertStridedSliceOp> {
4013 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4014 PatternRewriter &rewriter)
const override {
4016 auto dst = insertStridedSliceOp.getDest();
4017 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4018 if (!splat || getScalarSplatSource(dst) != splat)
4021 rewriter.
replaceOp(insertStridedSliceOp, dst);
4028class FoldInsertStridedSliceOfExtract final
4029 :
public OpRewritePattern<InsertStridedSliceOp> {
4033 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4034 PatternRewriter &rewriter)
const override {
4035 auto extractStridedSliceOp =
4036 insertStridedSliceOp.getValueToStore()
4037 .getDefiningOp<vector::ExtractStridedSliceOp>();
4039 if (!extractStridedSliceOp)
4042 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4046 if (extractStridedSliceOp.getStrides() !=
4047 insertStridedSliceOp.getStrides() ||
4048 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4051 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4058class InsertStridedSliceConstantFolder final
4059 :
public OpRewritePattern<InsertStridedSliceOp> {
4065 static constexpr int64_t vectorSizeFoldThreshold = 256;
4067 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4068 PatternRewriter &rewriter)
const override {
4072 Attribute vectorDestCst;
4076 VectorType destTy = destVector.getType();
4077 if (destTy.isScalable())
4081 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4082 !destVector.hasOneUse())
4086 Attribute sourceCst;
4096 if (op.hasNonUnitStrides())
4099 VectorType sliceVecTy = sourceValue.getType();
4100 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4101 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4102 SmallVector<int64_t, 4> offsets =
getI64SubArray(op.getOffsets());
4103 SmallVector<int64_t, 4> destStrides =
computeStrides(destTy.getShape());
4111 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4112 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4113 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4114 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4115 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4116 MutableArrayRef<int64_t> currSlicePosition(
4117 currDestPosition.begin() + rankDifference, currDestPosition.end());
4118 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4121 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
4122 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
4123 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4124 "Invalid slice element");
4125 newValues[linearizedPosition] = *sliceValuesIt;
4138void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4139 RewritePatternSet &results, MLIRContext *context) {
4140 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4141 InsertStridedSliceConstantFolder>(context);
4144OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4145 if (getSourceVectorType() == getDestVectorType())
4146 return getValueToStore();
4155void OuterProductOp::build(OpBuilder &builder, OperationState &
result,
4156 Value
lhs, Value
rhs, Value acc) {
4161void OuterProductOp::print(OpAsmPrinter &p) {
4162 p <<
" " << getLhs() <<
", " << getRhs();
4164 p <<
", " << getAcc();
4167 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
4170ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &
result) {
4171 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4178 if (operandsInfo.size() < 2)
4180 "expected at least 2 operands");
4181 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4182 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4185 "expected vector type for operand #1");
4189 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4190 vRHS.getScalableDims()[0]};
4191 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4192 vLHS.getElementType(), scalableDimsRes);
4195 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4196 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4200 if (!
result.attributes.get(OuterProductOp::getKindAttrName(
result.name))) {
4201 result.attributes.append(
4202 OuterProductOp::getKindAttrName(
result.name),
4203 CombiningKindAttr::get(
result.getContext(),
4204 OuterProductOp::getDefaultKind()));
4210 (operandsInfo.size() > 2 &&
4215LogicalResult OuterProductOp::verify() {
4216 Type tRHS = getOperandTypeRHS();
4217 VectorType vLHS = getOperandVectorTypeLHS(),
4218 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4219 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4221 if (vLHS.getRank() != 1)
4222 return emitOpError(
"expected 1-d vector for operand #1");
4226 if (vRHS.getRank() != 1)
4227 return emitOpError(
"expected 1-d vector for operand #2");
4228 if (vRES.getRank() != 2)
4230 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4231 return emitOpError(
"expected #1 operand dim to match result dim #1");
4232 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4233 return emitOpError(
"expected #2 operand dim to match result dim #2");
4234 if (vLHS.isScalable() && !vRHS.isScalable()) {
4238 "expected either both or only #2 operand dim to be scalable");
4242 if (vRES.getRank() != 1)
4244 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4245 return emitOpError(
"expected #1 operand dim to match result dim #1");
4248 if (vACC && vACC != vRES)
4249 return emitOpError(
"expected operand #3 of same type as result type");
4251 if (!getKindAttr()) {
4252 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
4253 "'vector.kind<add>')");
4258 return emitOpError(
"unsupported outerproduct type");
4267Type OuterProductOp::getExpectedMaskType() {
4268 auto vecType = this->getResultVectorType();
4269 return VectorType::get(vecType.getShape(),
4270 IntegerType::get(vecType.getContext(), 1),
4271 vecType.getScalableDims());
4285 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4287 shape.reserve(vectorType.getRank());
4289 for (
unsigned e = offsets.size(); idx < e; ++idx)
4290 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4291 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4292 shape.push_back(vectorType.getShape()[idx]);
4294 return VectorType::get(
shape, vectorType.getElementType(),
4295 vectorType.getScalableDims());
4298void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4299 Value source, ArrayRef<int64_t> offsets,
4300 ArrayRef<int64_t> sizes,
4301 ArrayRef<int64_t> strides) {
4302 result.addOperands(source);
4308 offsetsAttr, sizesAttr, stridesAttr));
4309 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(
result.name),
4311 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(
result.name),
4313 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(
result.name),
4317LogicalResult ExtractStridedSliceOp::verify() {
4318 auto type = getSourceVectorType();
4319 auto offsets = getOffsetsAttr();
4320 auto sizes = getSizesAttr();
4321 auto strides = getStridesAttr();
4322 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4324 "expected offsets, sizes and strides attributes of same size");
4326 auto shape = type.getShape();
4327 auto offName = getOffsetsAttrName();
4328 auto sizesName = getSizesAttrName();
4329 auto stridesName = getStridesAttrName();
4345 shape, offName, sizesName,
4350 offsets, sizes, strides);
4351 if (getResult().
getType() != resultType)
4352 return emitOpError(
"expected result type to be ") << resultType;
4354 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
4355 if (type.getScalableDims()[idx]) {
4356 auto inputDim = type.getShape()[idx];
4357 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4358 if (inputDim != inputSize)
4361 << (
" to match the corresponding base size from the input "
4363 << inputSize << (
" vs ") << inputDim << (
")");
4376 auto getElement = [](
ArrayAttr array,
int idx) {
4377 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4379 ArrayAttr extractOffsets = op.getOffsets();
4382 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4384 if (op.getSourceVectorType().getRank() !=
4385 insertOp.getSourceVectorType().getRank())
4387 ArrayAttr insertOffsets = insertOp.getOffsets();
4388 ArrayAttr insertStrides = insertOp.getStrides();
4391 if (extractOffsets.size() > insertOffsets.size())
4393 bool patialoverlap =
false;
4394 bool disjoint =
false;
4396 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4397 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
4399 int64_t start = getElement(insertOffsets, dim);
4400 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4401 int64_t offset = getElement(extractOffsets, dim);
4402 int64_t size = getElement(extractSizes, dim);
4404 if (start <= offset && offset < end) {
4407 if (offset + size > end)
4408 patialoverlap =
true;
4409 offsetDiffs.push_back(offset - start);
4416 if (!disjoint && !patialoverlap) {
4417 op.setOperand(insertOp.getValueToStore());
4420 op.setOffsetsAttr(
b.getI64ArrayAttr(offsetDiffs));
4426 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4441 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4446 if (op.hasNonUnitStrides())
4449 VectorType sourceVecTy = op.getSourceVectorType();
4453 VectorType sliceVecTy = op.getType();
4455 int64_t rank = sliceVecTy.getRank();
4467 const auto denseValuesBegin = dense.value_begin<
Attribute>();
4469 sliceValues.reserve(sliceVecTy.getNumElements());
4473 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4475 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4476 }
while (succeeded(
incSlicePosition(currSlicePosition, sliceShape, offsets)));
4478 assert(
static_cast<int64_t>(sliceValues.size()) ==
4479 sliceVecTy.getNumElements() &&
4480 "Invalid number of slice elements");
4484OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4485 if (getSourceVectorType() == getResult().
getType())
4492 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4499void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4521class StridedSliceFolder final
4522 :
public OpRewritePattern<ExtractStridedSliceOp> {
4524 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4526 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4527 PatternRewriter &rewriter)
const override {
4528 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4532 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4535 SmallVector<int64_t> firstOffsets =
getI64SubArray(firstOp.getOffsets());
4536 SmallVector<int64_t> firstSizes =
getI64SubArray(firstOp.getSizes());
4537 SmallVector<int64_t> secondOffsets =
getI64SubArray(secondOp.getOffsets());
4538 SmallVector<int64_t> secondSizes =
getI64SubArray(secondOp.getSizes());
4540 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4541 SmallVector<int64_t> combinedOffsets(newRank, 0);
4542 SmallVector<int64_t> combinedSizes(newRank);
4543 ArrayRef<int64_t> firstSourceShape =
4544 firstOp.getSourceVectorType().getShape();
4545 for (
unsigned i = 0; i < newRank; ++i) {
4546 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4547 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4548 combinedOffsets[i] = off1 + off2;
4550 if (i < secondSizes.size()) {
4551 combinedSizes[i] = secondSizes[i];
4552 }
else if (i < firstSizes.size()) {
4553 combinedSizes[i] = firstSizes[i];
4555 combinedSizes[i] = firstSourceShape[i];
4559 SmallVector<int64_t> combinedStrides(newRank, 1);
4561 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4579class StridedSliceCreateMaskFolder final
4580 :
public OpRewritePattern<ExtractStridedSliceOp> {
4584 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4585 PatternRewriter &rewriter)
const override {
4586 Location loc = extractStridedSliceOp.getLoc();
4590 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4594 if (extractStridedSliceOp.hasNonUnitStrides())
4597 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4599 SmallVector<int64_t> sliceOffsets;
4602 SmallVector<int64_t> sliceSizes;
4606 SmallVector<Value> sliceMaskDimSizes;
4607 sliceMaskDimSizes.reserve(maskDimSizes.size());
4611 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4612 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4616 IntegerAttr offsetAttr =
4618 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4619 Value sliceMaskDimSize =
4620 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4621 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4626 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4630 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4638class StridedSliceConstantMaskFolder final
4639 :
public OpRewritePattern<ExtractStridedSliceOp> {
4643 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4644 PatternRewriter &rewriter)
const override {
4647 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4648 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4649 if (!constantMaskOp)
4652 if (extractStridedSliceOp.hasNonUnitStrides())
4655 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4657 SmallVector<int64_t> sliceOffsets;
4660 SmallVector<int64_t> sliceSizes;
4664 SmallVector<int64_t> sliceMaskDimSizes;
4665 sliceMaskDimSizes.reserve(maskDimSizes.size());
4666 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4667 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4668 int64_t sliceMaskDimSize = std::max(
4669 static_cast<int64_t
>(0),
4670 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4671 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4674 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4675 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4676 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4679 if (llvm::is_contained(sliceMaskDimSizes, 0))
4680 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4685 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4693class StridedSliceBroadcast final
4694 :
public OpRewritePattern<ExtractStridedSliceOp> {
4698 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4699 PatternRewriter &rewriter)
const override {
4705 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4706 auto dstVecType = llvm::cast<VectorType>(op.getType());
4707 unsigned dstRank = dstVecType.getRank();
4708 unsigned rankDiff = dstRank - srcRank;
4712 bool needsSlice =
false;
4713 for (
unsigned i = 0; i < srcRank; i++) {
4714 if (srcVecType.getDimSize(i) != 1 &&
4715 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4722 SmallVector<int64_t> offsets =
4724 SmallVector<int64_t> sizes =
4726 for (
unsigned i = 0; i < srcRank; i++) {
4727 if (srcVecType.getDimSize(i) == 1) {
4735 source = ExtractStridedSliceOp::create(
4736 rewriter, op->getLoc(), source, offsets, sizes,
4745class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
4749 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4750 PatternRewriter &rewriter)
const override {
4752 Value splat = getScalarSplatSource(op.getSource());
4776class ContiguousExtractStridedSliceToExtract final
4777 :
public OpRewritePattern<ExtractStridedSliceOp> {
4781 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4782 PatternRewriter &rewriter)
const override {
4783 if (op.hasNonUnitStrides())
4785 Value source = op.getOperand();
4786 auto sourceType = cast<VectorType>(source.
getType());
4787 if (sourceType.isScalable() || sourceType.getRank() == 0)
4796 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4797 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4804 if (numOffsets == 0)
4809 if (numOffsets == sourceType.getRank() &&
4810 static_cast<int>(sizes.size()) == sourceType.getRank())
4814 for (
int i = 0; i < numOffsets; ++i) {
4822 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
4823 sizes[numOffsets] == 1) {
4828 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4829 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4838void ExtractStridedSliceOp::getCanonicalizationPatterns(
4839 RewritePatternSet &results, MLIRContext *context) {
4842 results.
add<StridedSliceFolder, StridedSliceCreateMaskFolder,
4843 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4844 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4853void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4854 VectorType vectorType, Value source,
4856 AffineMapAttr permutationMapAttr,
4859 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4861 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4862 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
4863 *padding, Value(), inBoundsAttr);
4867void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4868 VectorType vectorType, Value source,
4870 AffineMap permutationMap,
4871 std::optional<ArrayRef<bool>> inBounds) {
4872 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4873 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4876 SmallVector<bool>(vectorType.getRank(),
false));
4877 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4879 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4880 build(builder,
result, vectorType, source,
indices, *padding,
4881 permutationMapAttr, inBoundsAttr);
4885void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4886 VectorType vectorType, Value source,
4888 std::optional<ArrayRef<bool>> inBounds) {
4890 llvm::cast<ShapedType>(source.
getType()), vectorType);
4891 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4892 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4895 SmallVector<bool>(vectorType.getRank(),
false));
4896 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4898 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4899 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
4901 Value(), inBoundsAttr);
4904template <
typename EmitFun>
4908 for (
auto expr : permutationMap.
getResults()) {
4909 auto dim = dyn_cast<AffineDimExpr>(expr);
4910 auto zero = dyn_cast<AffineConstantExpr>(expr);
4912 if (zero.getValue() != 0) {
4914 "requires a projected permutation_map (at most one dim or the zero "
4915 "constant can appear in each result)");
4920 return emitOpError(
"requires a projected permutation_map (at most one "
4921 "dim or the zero constant can appear in each result)");
4923 if (seen[dim.getPosition()]) {
4925 "requires a permutation_map that is a permutation (found one dim "
4926 "used more than once)");
4928 seen[dim.getPosition()] =
true;
4935 VectorType vectorType, VectorType maskType,
4936 VectorType inferredMaskType,
AffineMap permutationMap,
4938 if (op->hasAttr(
"masked")) {
4939 return op->emitOpError(
"masked attribute has been removed. "
4940 "Use in_bounds instead.");
4943 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4944 return op->emitOpError(
4945 "requires source to be a memref or ranked tensor type");
4947 auto elementType = shapedType.getElementType();
4949 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4951 unsigned sourceVecSize =
4953 vectorElementType.getShape().back();
4954 unsigned resultVecSize =
4956 vectorType.getShape().back();
4957 if (resultVecSize % sourceVecSize != 0)
4958 return op->emitOpError(
4959 "requires the bitwidth of the minor 1-D vector to be an integral "
4960 "multiple of the bitwidth of the minor 1-D vector of the source");
4962 unsigned sourceVecEltRank = vectorElementType.getRank();
4963 unsigned resultVecRank = vectorType.getRank();
4964 if (sourceVecEltRank > resultVecRank)
4965 return op->emitOpError(
4966 "requires source vector element and vector result ranks to match.");
4967 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4970 return op->emitOpError(
"requires a permutation_map with result dims of "
4971 "the same rank as the vector type");
4974 return op->emitOpError(
"does not support masks with vector element type");
4977 unsigned minorSize =
4978 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4979 unsigned resultVecSize =
4982 return op->emitOpError(
4983 "requires the bitwidth of the minor 1-D vector to be an integral "
4984 "multiple of the bitwidth of the source element type");
4988 return op->emitOpError(
"requires a permutation_map with result dims of "
4989 "the same rank as the vector type");
4993 return op->emitOpError(
"requires permutation_map without symbols");
4995 if (permutationMap.
getNumInputs() != shapedType.getRank())
4996 return op->emitOpError(
"requires a permutation_map with input dims of the "
4997 "same rank as the source type");
4999 if (maskType && maskType != inferredMaskType)
5000 return op->emitOpError(
"inferred mask type (")
5001 << inferredMaskType <<
") and mask operand type (" << maskType
5005 return op->emitOpError(
"expects the in_bounds attr of same rank "
5006 "as permutation_map results: ")
5007 << AffineMapAttr::get(permutationMap)
5008 <<
" vs inBounds of size: " << inBounds.size();
5015 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5016 if (op.getPermutationMap().isMinorIdentity())
5017 elidedAttrs.push_back(op.getPermutationMapAttrName());
5019 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
5020 elidedAttrs.push_back(op.getInBoundsAttrName());
5024void TransferReadOp::print(OpAsmPrinter &p) {
5027 p <<
", " << getMask();
5034 auto i1Type = IntegerType::get(permMap.
getContext(), 1);
5036 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
5041 if (maskShape.empty())
5042 maskShape.push_back(1);
5047 return VectorType::get(maskShape, i1Type, scalableDims);
5064 if (hasMask.succeeded()) {
5071 if (types.size() != 2)
5072 return parser.
emitError(typesLoc,
"requires two types");
5074 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5075 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5076 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5077 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5079 return parser.
emitError(typesLoc,
"requires vector type");
5080 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(
result.name);
5084 if (shapedType.getRank() <
5087 "expected a custom permutation_map when "
5088 "rank(source) != rank(destination)");
5090 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5092 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5094 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(
result.name);
5095 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5096 if (!inBoundsAttr) {
5097 result.addAttribute(inBoundsAttrName,
5106 if (hasMask.succeeded()) {
5107 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5109 maskInfo.
location,
"does not support masks with vector element type");
5112 "expected the same rank for the vector and the "
5113 "results of the permutation map");
5121 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5123 {1, static_cast<int32_t>(indexInfo.size()), 1,
5124 static_cast<int32_t>(hasMask.succeeded())}));
5128LogicalResult TransferReadOp::verify() {
5130 ShapedType shapedType = getShapedType();
5132 VectorType maskType = getMaskType();
5133 auto paddingType = getPadding().getType();
5134 auto permutationMap = getPermutationMap();
5135 VectorType inferredMaskType =
5138 auto sourceElementType = shapedType.getElementType();
5140 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
5141 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5144 shapedType, vectorType, maskType,
5145 inferredMaskType, permutationMap, getInBounds())))
5148 if (
auto sourceVectorElementType =
5149 llvm::dyn_cast<VectorType>(sourceElementType)) {
5152 if (sourceVectorElementType != paddingType)
5154 "requires source element type and padding type to match.");
5158 if (!VectorType::isValidElementType(paddingType))
5159 return emitOpError(
"requires valid padding vector elemental type");
5162 if (paddingType != sourceElementType)
5164 "requires formal padding and source of the same elemental type");
5175Type TransferReadOp::getExpectedMaskType() {
5182VectorType TransferReadOp::getVectorType() {
5183 return cast<VectorType>(getVector().
getType());
5186template <
typename TransferOp>
5190 if (op.getShapedType().isDynamicDim(indicesIdx))
5194 if (!cstOp.has_value())
5197 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5198 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5200 return cstOp.value() + vectorSize <= sourceSize;
5203template <
typename TransferOp>
5207 if (op.getTransferRank() == 0)
5210 bool changed =
false;
5212 newInBounds.reserve(op.getTransferRank());
5217 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
5219 if (op.isDimInBounds(i)) {
5220 newInBounds.push_back(
true);
5225 bool inBounds =
false;
5226 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
5229 dimExpr.getPosition());
5230 nonBcastDims.push_back(i);
5233 newInBounds.push_back(inBounds);
5235 changed |= inBounds;
5241 bool allNonBcastDimsInBounds = llvm::all_of(
5242 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
5243 if (allNonBcastDimsInBounds) {
5245 changed |= !newInBounds[idx];
5246 newInBounds[idx] =
true;
5254 op.setInBoundsAttr(
b.getBoolArrayAttr(newInBounds));
5258template <
typename TransferOp>
5260 auto mask = op.getMask();
5267 op.getMaskMutable().clear();
5281static Value foldRAW(TransferReadOp readOp) {
5282 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5284 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5287 return defWrite.getVector();
5289 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5290 cast<VectorTransferOpInterface>(readOp.getOperation())))
5292 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5297OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5298 if (Value vec = foldRAW(*
this))
5309 return OpFoldResult();
5312std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5316void TransferReadOp::getEffects(
5317 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5319 if (llvm::isa<MemRefType>(getShapedType()))
5320 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5321 SideEffects::DefaultResource::get());
5325 if (hasPureTensorSemantics())
5332static AffineMap inverseWithUnusedDims(AffineMap map) {
5334 "expected a projected permutation map");
5339 int64_t pos = cast<AffineDimExpr>(
result).getPosition();
5369struct TransferReadAfterWriteToBroadcast
5370 :
public OpRewritePattern<TransferReadOp> {
5373 LogicalResult matchAndRewrite(TransferReadOp readOp,
5374 PatternRewriter &rewriter)
const override {
5375 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5379 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5383 if (readOp.getMask() || defWrite.getMask())
5386 if (readOp.getIndices() != defWrite.getIndices())
5389 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5393 if (readOp.getTransferChunkAccessed() !=
5394 defWrite.getTransferChunkAccessed())
5401 AffineMap readMap = readOp.getPermutationMap();
5402 AffineMap writeMap = defWrite.getPermutationMap();
5403 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5404 AffineMap composedMap = readMap.
compose(invWriteMap);
5418 int64_t numBroadcastedDims = broadcastedDims.size();
5419 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5421 for (
auto [idx, expr] : llvm::enumerate(composedMap.
getResults())) {
5422 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
5423 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5424 invPerm[effectiveDim] = idx;
5429 VectorType readVecTy = readOp.getVectorType();
5431 auto broadcastedVecTy =
5433 readVecTy.getElementType(),
5436 Value vec = defWrite.getVector();
5437 Location loc = readOp.getLoc();
5438 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5445void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5446 MLIRContext *context) {
5447 results.
add<TransferReadAfterWriteToBroadcast>(context);
5450FailureOr<std::optional<SmallVector<Value>>>
5451TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5452 if (!hasPureBufferSemantics())
5463void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5465 AffineMapAttr permutationMapAttr,
5468 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
5469 build(builder,
result, resultType, vector, dest,
indices, permutationMapAttr,
5470 mask, inBoundsAttr);
5474void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5476 AffineMapAttr permutationMapAttr,
5478 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5479 Value(), inBoundsAttr);
5484void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5486 AffineMap permutationMap,
5487 std::optional<ArrayRef<bool>> inBounds) {
5488 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5490 (inBounds && !inBounds.value().empty())
5493 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
5494 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5495 Value(), inBoundsAttr);
5500void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5502 std::optional<ArrayRef<bool>> inBounds) {
5503 auto vectorType = llvm::cast<VectorType>(vector.
getType());
5505 llvm::cast<ShapedType>(dest.
getType()), vectorType);
5506 build(builder,
result, vector, dest,
indices, permutationMap, inBounds);
5509ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5510 OperationState &
result) {
5513 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5514 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5515 SmallVector<Type, 2> types;
5516 OpAsmParser::UnresolvedOperand maskInfo;
5522 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
5527 if (types.size() != 2)
5528 return parser.
emitError(typesLoc,
"requires two types");
5530 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5532 return parser.
emitError(typesLoc,
"requires vector type");
5533 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5534 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5535 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5536 auto permMapAttrName =
5537 TransferWriteOp::getPermutationMapAttrName(
result.name);
5538 auto permMapAttr =
result.attributes.get(permMapAttrName);
5541 if (shapedType.getRank() <
5544 "expected a custom permutation_map when "
5545 "rank(source) != rank(destination)");
5547 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5549 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5551 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(
result.name);
5552 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5553 if (!inBoundsAttr) {
5554 result.addAttribute(inBoundsAttrName,
5562 if (hasMask.succeeded()) {
5563 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5565 maskInfo.
location,
"does not support masks with vector element type");
5568 "expected the same rank for the vector and the "
5569 "results of the permutation map");
5575 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5577 {1, 1, static_cast<int32_t>(indexInfo.size()),
5578 static_cast<int32_t>(hasMask.succeeded())}));
5579 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5583void TransferWriteOp::print(OpAsmPrinter &p) {
5586 p <<
", " << getMask();
5591LogicalResult TransferWriteOp::verify() {
5593 ShapedType shapedType = getShapedType();
5595 VectorType maskType = getMaskType();
5596 auto permutationMap = getPermutationMap();
5597 VectorType inferredMaskType =
5601 if (llvm::size(
getIndices()) != shapedType.getRank())
5602 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5606 if (hasBroadcastDim())
5607 return emitOpError(
"should not have broadcast dimensions");
5610 shapedType, vectorType, maskType,
5611 inferredMaskType, permutationMap, getInBounds())))
5624Type TransferWriteOp::getExpectedMaskType() {
5631Value TransferWriteOp::getVector() {
return getOperand(0); }
5632VectorType TransferWriteOp::getVectorType() {
5633 return cast<VectorType>(getValueToStore().
getType());
5656static LogicalResult foldReadInitWrite(TransferWriteOp write,
5657 ArrayRef<Attribute>,
5658 SmallVectorImpl<OpFoldResult> &results) {
5660 if (write.getTransferRank() == 0)
5662 auto rankedTensorType =
5663 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5665 if (!rankedTensorType)
5668 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5672 if (read.getTransferRank() == 0)
5675 if (!read.getPermutationMap().isMinorIdentity() ||
5676 !write.getPermutationMap().isMinorIdentity())
5679 if (read.getTransferRank() != write.getTransferRank())
5682 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5685 if (read.getBase().getType() != rankedTensorType)
5688 if (read.getVectorType() != write.getVectorType())
5691 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5694 auto isNotConstantZero = [](Value v) {
5696 return !cstOp.has_value() || cstOp.value() != 0;
5698 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5699 llvm::any_of(write.getIndices(), isNotConstantZero))
5702 results.push_back(read.getBase());
5706static bool checkSameValueWAR(vector::TransferReadOp read,
5707 vector::TransferWriteOp write) {
5708 return read.getBase() == write.getBase() &&
5709 read.getIndices() == write.getIndices() &&
5710 read.getPermutationMap() == write.getPermutationMap() &&
5711 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5728static LogicalResult foldWAR(TransferWriteOp write,
5729 SmallVectorImpl<OpFoldResult> &results) {
5730 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5732 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5736 if (!checkSameValueWAR(read, write))
5738 results.push_back(read.getBase());
5742LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5743 SmallVectorImpl<OpFoldResult> &results) {
5744 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
5746 if (succeeded(foldWAR(*
this, results)))
5758std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5762void TransferWriteOp::getEffects(
5763 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5765 if (llvm::isa<MemRefType>(getShapedType()))
5766 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5767 SideEffects::DefaultResource::get());
5771 if (hasPureTensorSemantics())
5801class FoldWaw final :
public OpRewritePattern<TransferWriteOp> {
5804 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5805 PatternRewriter &rewriter)
const override {
5806 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5808 vector::TransferWriteOp writeToModify = writeOp;
5810 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5814 writeToModify.getBaseMutable().assign(defWrite.getBase());
5819 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5820 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5824 if (!defWrite->hasOneUse())
5826 writeToModify = defWrite;
5827 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5856struct SwapExtractSliceOfTransferWrite
5857 :
public OpRewritePattern<tensor::InsertSliceOp> {
5861 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5862 PatternRewriter &rewriter)
const override {
5863 if (!insertOp.hasUnitStride())
5866 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5867 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5869 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5870 if (!transferOp || !transferOp->hasOneUse())
5875 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5877 "use-def chain is rank-reducing");
5881 if (!extractOp.hasZeroOffset()) {
5883 "ExtractSliceOp has non-zero offset");
5887 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5888 return getConstantIntValue(value) == static_cast<int64_t>(0);
5891 "TranferWriteOp has non-zero offset");
5895 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5897 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
5900 for (
auto [insertSize, extractSize] :
5901 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5904 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
5909 assert(transferOp.getVectorType().hasStaticShape() &&
5910 "expected vector to have a static shape");
5911 ArrayRef<int64_t>
vectorShape = transferOp.getVectorType().getShape();
5913 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5914 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
5916 insertOp,
"TransferWriteOp may not write the full tensor.");
5921 SmallVector<bool> newInBounds(
vectorShape.size(),
false);
5922 auto newExtractOp = tensor::ExtractSliceOp::create(
5923 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5924 insertOp.getDest(), insertOp.getMixedOffsets(),
5925 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5926 auto newTransferWriteOp = TransferWriteOp::create(
5927 rewriter, transferOp.getLoc(), transferOp.getVector(),
5928 newExtractOp.getResult(), transferOp.getIndices(),
5929 transferOp.getPermutationMapAttr(),
5932 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5940void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5941 MLIRContext *context) {
5942 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5945FailureOr<std::optional<SmallVector<Value>>>
5946TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5947 if (!hasPureBufferSemantics())
5957static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5959 MemRefType memRefTy) {
5962 if (!vecTy.isScalable() &&
5963 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5966 if (!memRefTy.isLastDimUnitStride())
5967 return op->
emitOpError(
"most minor memref dim must have unit stride");
5971LogicalResult vector::LoadOp::verify() {
5975 if (
failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
5978 if (memRefTy.getRank() < resVecTy.getRank())
5980 "destination memref has lower rank than the result vector");
5983 Type memElemTy = memRefTy.getElementType();
5984 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5985 if (memVecTy != resVecTy)
5986 return emitOpError(
"base memref and result vector types should match");
5987 memElemTy = memVecTy.getElementType();
5990 if (resVecTy.getElementType() != memElemTy)
5991 return emitOpError(
"base and result element types should match");
5992 if (llvm::size(
getIndices()) != memRefTy.getRank())
5993 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5997OpFoldResult LoadOp::fold(FoldAdaptor) {
6000 return OpFoldResult();
6003std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6007FailureOr<std::optional<SmallVector<Value>>>
6008LoadOp::bubbleDownCasts(OpBuilder &builder) {
6017LogicalResult vector::StoreOp::verify() {
6021 if (
failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
6024 if (memRefTy.getRank() < valueVecTy.getRank())
6025 return emitOpError(
"source memref has lower rank than the vector to store");
6028 Type memElemTy = memRefTy.getElementType();
6029 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6030 if (memVecTy != valueVecTy)
6032 "base memref and valueToStore vector types should match");
6033 memElemTy = memVecTy.getElementType();
6036 if (valueVecTy.getElementType() != memElemTy)
6037 return emitOpError(
"base and valueToStore element type should match");
6038 if (llvm::size(
getIndices()) != memRefTy.getRank())
6039 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6043LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6044 SmallVectorImpl<OpFoldResult> &results) {
6048std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6052FailureOr<std::optional<SmallVector<Value>>>
6053StoreOp::bubbleDownCasts(OpBuilder &builder) {
6062LogicalResult MaskedLoadOp::verify() {
6063 VectorType maskVType = getMaskVectorType();
6064 VectorType passVType = getPassThruVectorType();
6071 if (llvm::size(
getIndices()) != memType.getRank())
6072 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6073 if (resVType.getShape() != maskVType.getShape())
6074 return emitOpError(
"expected result shape to match mask shape");
6075 if (resVType != passVType)
6076 return emitOpError(
"expected pass_thru of same type as result type");
6081class MaskedLoadFolder final :
public OpRewritePattern<MaskedLoadOp> {
6084 LogicalResult matchAndRewrite(MaskedLoadOp
load,
6085 PatternRewriter &rewriter)
const override {
6097 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
6102void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6103 MLIRContext *context) {
6104 results.
add<MaskedLoadFolder>(context);
6107OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6110 return OpFoldResult();
6113FailureOr<std::optional<SmallVector<Value>>>
6114MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6123LogicalResult MaskedStoreOp::verify() {
6124 VectorType maskVType = getMaskVectorType();
6131 if (llvm::size(
getIndices()) != memType.getRank())
6132 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6133 if (valueVType.getShape() != maskVType.getShape())
6134 return emitOpError(
"expected valueToStore shape to match mask shape");
6139class MaskedStoreFolder final :
public OpRewritePattern<MaskedStoreOp> {
6142 LogicalResult matchAndRewrite(MaskedStoreOp store,
6143 PatternRewriter &rewriter)
const override {
6147 store, store.getValueToStore(), store.getBase(), store.getIndices());
6155 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
6160void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6161 MLIRContext *context) {
6162 results.
add<MaskedStoreFolder>(context);
6165LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6166 SmallVectorImpl<OpFoldResult> &results) {
6170FailureOr<std::optional<SmallVector<Value>>>
6171MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6180LogicalResult GatherOp::verify() {
6181 VectorType indVType = getIndexVectorType();
6182 VectorType maskVType = getMaskVectorType();
6184 ShapedType baseType = getBaseType();
6186 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6187 return emitOpError(
"requires base to be a memref or ranked tensor type");
6192 if (llvm::size(getOffsets()) != baseType.getRank())
6193 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6194 if (resVType.getShape() != indVType.getShape())
6195 return emitOpError(
"expected result dim to match indices dim");
6196 if (resVType.getShape() != maskVType.getShape())
6197 return emitOpError(
"expected result dim to match mask dim");
6198 if (resVType != getPassThruVectorType())
6199 return emitOpError(
"expected pass_thru of same type as result type");
6200 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6202 "alignment is only supported for memref bases, not tensor bases");
6211Type GatherOp::getExpectedMaskType() {
6212 auto vecType = this->getIndexVectorType();
6213 return VectorType::get(vecType.getShape(),
6214 IntegerType::get(vecType.getContext(), 1),
6215 vecType.getScalableDims());
6218std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6223static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6224 auto vecType = dyn_cast<VectorType>(indexVec.
getType());
6225 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6231 DenseIntElementsAttr elements;
6236 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6240class GatherFolder final :
public OpRewritePattern<GatherOp> {
6243 LogicalResult matchAndRewrite(GatherOp gather,
6244 PatternRewriter &rewriter)
const override {
6249 rewriter.
replaceOp(gather, gather.getPassThru());
6254 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
6260class FoldContiguousGather final :
public OpRewritePattern<GatherOp> {
6263 LogicalResult matchAndRewrite(GatherOp op,
6264 PatternRewriter &rewriter)
const override {
6265 if (!isa<MemRefType>(op.getBase().getType()))
6268 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6272 op.getOffsets(), op.getMask(),
6279void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6280 MLIRContext *context) {
6281 results.
add<GatherFolder, FoldContiguousGather>(context);
6284FailureOr<std::optional<SmallVector<Value>>>
6285GatherOp::bubbleDownCasts(OpBuilder &builder) {
6294LogicalResult ScatterOp::verify() {
6295 VectorType indVType = getIndexVectorType();
6296 VectorType maskVType = getMaskVectorType();
6298 ShapedType baseType = getBaseType();
6300 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6301 return emitOpError(
"requires base to be a memref or ranked tensor type");
6306 if (llvm::size(getOffsets()) != baseType.getRank())
6307 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6308 if (valueVType.getShape() != indVType.getShape())
6309 return emitOpError(
"expected valueToStore dim to match indices dim");
6310 if (valueVType.getShape() != maskVType.getShape())
6311 return emitOpError(
"expected valueToStore dim to match mask dim");
6312 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6314 "alignment is only supported for memref bases, not tensor bases");
6319class ScatterFolder final :
public OpRewritePattern<ScatterOp> {
6322 LogicalResult matchAndRewrite(ScatterOp scatter,
6323 PatternRewriter &rewriter)
const override {
6324 ShapedType baseType = scatter.getBaseType();
6325 bool isMemRef = isa<MemRefType>(baseType);
6326 if (!isMemRef && !isa<RankedTensorType>(baseType))
6339 rewriter.
replaceOp(scatter, scatter.getBase());
6344 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
6350class FoldContiguousScatter final :
public OpRewritePattern<ScatterOp> {
6353 LogicalResult matchAndRewrite(ScatterOp op,
6354 PatternRewriter &rewriter)
const override {
6357 if (!isa<MemRefType>(op.getBase().getType()))
6360 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6364 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6370void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6371 MLIRContext *context) {
6372 results.
add<ScatterFolder, FoldContiguousScatter>(context);
6375FailureOr<std::optional<SmallVector<Value>>>
6376ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6385LogicalResult ExpandLoadOp::verify() {
6386 VectorType maskVType = getMaskVectorType();
6387 VectorType passVType = getPassThruVectorType();
6394 if (llvm::size(
getIndices()) != memType.getRank())
6395 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6396 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6397 return emitOpError(
"expected result dim to match mask dim");
6398 if (resVType != passVType)
6399 return emitOpError(
"expected pass_thru of same type as result type");
6404class ExpandLoadFolder final :
public OpRewritePattern<ExpandLoadOp> {
6407 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6408 PatternRewriter &rewriter)
const override {
6412 expand, expand.getType(), expand.getBase(), expand.getIndices());
6415 rewriter.
replaceOp(expand, expand.getPassThru());
6420 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
6425void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6426 MLIRContext *context) {
6427 results.
add<ExpandLoadFolder>(context);
6430FailureOr<std::optional<SmallVector<Value>>>
6431ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6440LogicalResult CompressStoreOp::verify() {
6441 VectorType maskVType = getMaskVectorType();
6448 if (llvm::size(
getIndices()) != memType.getRank())
6449 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6450 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6451 return emitOpError(
"expected valueToStore dim to match mask dim");
6456class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
6459 LogicalResult matchAndRewrite(CompressStoreOp compress,
6460 PatternRewriter &rewriter)
const override {
6464 compress, compress.getValueToStore(), compress.getBase(),
6465 compress.getIndices());
6473 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
6478void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6479 MLIRContext *context) {
6480 results.
add<CompressStoreFolder>(context);
6483FailureOr<std::optional<SmallVector<Value>>>
6484CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6493void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6495 setResultRanges(getResult(), argRanges.front());
6498std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6499 return llvm::to_vector<4>(getResultVectorType().
getShape());
6502LogicalResult ShapeCastOp::verify() {
6504 VectorType sourceType = getSourceVectorType();
6505 VectorType resultType = getResultVectorType();
6513 int64_t sourceNElms = sourceType.getNumElements();
6514 int64_t resultNElms = resultType.getNumElements();
6515 if (sourceNElms != resultNElms) {
6516 return emitOpError() <<
"has different number of elements at source ("
6517 << sourceNElms <<
") and result (" << resultNElms
6522 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6523 int64_t resultNScalableDims = resultType.getNumScalableDims();
6524 if (sourceNScalableDims != resultNScalableDims)
6525 return emitOpError() <<
"has different number of scalable dims at source ("
6526 << sourceNScalableDims <<
") and result ("
6527 << resultNScalableDims <<
")";
6536static bool isOrderPreserving(TransposeOp transpose) {
6537 ArrayRef<int64_t> permutation = transpose.getPermutation();
6538 VectorType sourceType = transpose.getSourceVectorType();
6539 ArrayRef<int64_t> inShape = sourceType.getShape();
6540 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6541 auto isNonScalableUnitDim = [&](int64_t dim) {
6542 return inShape[dim] == 1 && !inDimIsScalable[dim];
6544 int64_t current = 0;
6545 for (
auto p : permutation) {
6546 if (!isNonScalableUnitDim(p)) {
6556OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6558 VectorType resultType =
getType();
6561 if (getSource().
getType() == resultType)
6565 if (
auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6566 setOperand(precedingShapeCast.getSource());
6571 if (
auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6572 if (isOrderPreserving(transpose)) {
6573 setOperand(transpose.getVector());
6581 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6582 if (bcastOp.getSourceType() == resultType)
6583 return bcastOp.getSource();
6587 if (
auto denseAttr =
6588 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6589 return denseAttr.reshape(
getType());
6605static VectorType trimTrailingOneDims(VectorType oldType) {
6606 ArrayRef<int64_t> oldShape = oldType.getShape();
6607 ArrayRef<int64_t> newShape = oldShape;
6609 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6610 ArrayRef<bool> newScalableDims = oldScalableDims;
6612 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6613 newShape = newShape.drop_back(1);
6614 newScalableDims = newScalableDims.drop_back(1);
6619 if (newShape.empty()) {
6620 newShape = oldShape.take_back();
6621 newScalableDims = oldScalableDims.take_back();
6624 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6639class ShapeCastCreateMaskFolderTrailingOneDim final
6640 :
public OpRewritePattern<ShapeCastOp> {
6644 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6645 PatternRewriter &rewriter)
const override {
6646 Value shapeOpSrc = shapeOp->getOperand(0);
6647 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
6648 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
6649 if (!createMaskOp && !constantMaskOp)
6652 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6653 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6655 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6656 if (newVecType != shapeOpResTy)
6659 auto numDimsToDrop =
6660 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6667 auto maskOperands = createMaskOp.getOperands();
6668 auto numMaskOperands = maskOperands.size();
6671 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6673 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6674 if (!constant || (constant.value() != 1))
6677 SmallVector<Value> newMaskOperands =
6678 maskOperands.drop_back(numDimsToDrop);
6685 if (constantMaskOp) {
6686 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6687 auto numMaskOperands = maskDimSizes.size();
6690 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6692 if (maskDimSizes[i] != 1)
6696 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6709int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6710 ArrayRef<int64_t> dstShape) {
6711 int stretchingFactor = 1;
6712 int numLeadingDims = dstShape.size() - srcShape.size();
6713 for (
int i = 0, e = srcShape.size(); i < e; i++) {
6714 int64_t dstDim = dstShape[numLeadingDims + i];
6715 if (srcShape[i] == 1 && dstDim != 1) {
6716 stretchingFactor *= dstDim;
6719 return stretchingFactor;
6723class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
6727 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6728 PatternRewriter &rewriter)
const override {
6730 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6734 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6735 bool srcIsScalar = !srcVectorType;
6743 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6744 ArrayRef<int64_t> dstShape = dstVectorType.getShape();
6745 ArrayRef<int64_t> srcShape =
6746 srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
6747 ArrayRef<int64_t> broadcastShape =
6748 broadcastOp.getResultVectorType().getShape();
6752 BroadcastableToResult::Success) {
6760 if (srcVectorType.getNumElements() != 1) {
6761 if (getBroadcastStretchingFactor(srcShape, dstShape) !=
6762 getBroadcastStretchingFactor(srcShape, broadcastShape)) {
6769 broadcastOp.getSource());
6788class FoldShapeCastOfFromElements final :
public OpRewritePattern<ShapeCastOp> {
6792 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6793 PatternRewriter &rewriter)
const override {
6794 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
6799 shapeCastOp, shapeCastOp.getResultVectorType(),
6800 fromElements.getElements());
6807void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6808 MLIRContext *context) {
6809 results.
add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
6810 FoldShapeCastOfFromElements>(context);
6817LogicalResult BitCastOp::verify() {
6818 auto sourceVectorType = getSourceVectorType();
6819 auto resultVectorType = getResultVectorType();
6821 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6822 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6823 return emitOpError(
"dimension size mismatch at: ") << i;
6826 DataLayout dataLayout = DataLayout::closest(*
this);
6827 auto sourceElementBits =
6829 auto resultElementBits =
6832 if (sourceVectorType.getRank() == 0) {
6833 if (sourceElementBits != resultElementBits)
6834 return emitOpError(
"source/result bitwidth of the 0-D vector element "
6835 "types must be equal");
6836 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
6837 resultElementBits * resultVectorType.getShape().back()) {
6839 "source/result bitwidth of the minor 1-D vectors must be equal");
6845OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6851 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6852 if (getResult().
getType() == otherOp.getSource().getType())
6853 return otherOp.getSource();
6855 setOperand(otherOp.getSource());
6859 Attribute sourceConstant = adaptor.getSource();
6860 if (!sourceConstant)
6863 Type srcElemType = getSourceVectorType().getElementType();
6864 Type dstElemType = getResultVectorType().getElementType();
6866 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6867 if (floatPack.isSplat()) {
6868 auto splat = floatPack.getSplatValue<FloatAttr>();
6871 if (srcElemType.
isF16() && dstElemType.
isF32()) {
6872 uint32_t bits =
static_cast<uint32_t
>(
6873 splat.getValue().bitcastToAPInt().getZExtValue());
6875 bits = (bits << 16) | (bits & 0xffff);
6876 APInt intBits(32, bits);
6877 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6883 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6884 if (intPack.isSplat()) {
6885 auto splat = intPack.getSplatValue<IntegerAttr>();
6887 if (llvm::isa<IntegerType>(dstElemType) && srcElemType.
isIntOrFloat()) {
6892 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6893 APInt intBits = splat.getValue().zext(dstBitWidth);
6896 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6897 intBits = (intBits << srcBitWidth) | intBits;
6911static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6912 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6913 SmallVector<int64_t, 8> res(memRefType.getShape());
6915 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6921void TypeCastOp::build(OpBuilder &builder, OperationState &
result,
6923 result.addOperands(source);
6924 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
6925 VectorType vectorType =
6926 VectorType::get(extractShape(memRefType),
6928 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6929 memRefType.getMemorySpace()));
6932LogicalResult TypeCastOp::verify() {
6933 MemRefType canonicalType =
getMemRefType().canonicalizeStridedLayout();
6934 if (!canonicalType.getLayout().isIdentity())
6935 return emitOpError(
"expects operand to be a memref with identity layout");
6936 if (!getResultMemRefType().getLayout().isIdentity())
6937 return emitOpError(
"expects result to be a memref with identity layout");
6938 if (getResultMemRefType().getMemorySpace() !=
6940 return emitOpError(
"expects result in same memory space");
6943 auto resultType = getResultMemRefType();
6947 "expects result and operand with same underlying scalar type: ")
6949 if (extractShape(sourceType) != extractShape(resultType))
6951 "expects concatenated result and operand shapes to be equal: ")
6960void vector::TransposeOp::build(OpBuilder &builder, OperationState &
result,
6961 Value vector, ArrayRef<int64_t> permutation) {
6962 VectorType vt = llvm::cast<VectorType>(vector.
getType());
6963 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6964 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6965 for (
unsigned i = 0; i < permutation.size(); ++i) {
6966 transposedShape[i] = vt.getShape()[permutation[i]];
6967 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6970 result.addOperands(vector);
6971 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6972 transposedScalableDims));
6973 result.addAttribute(TransposeOp::getPermutationAttrName(
result.name),
6977OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6980 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6981 return splat.reshape(getResultVectorType());
6998 if (getSourceVectorType() == getResultVectorType() &&
6999 isOrderPreserving(*
this))
7005LogicalResult vector::TransposeOp::verify() {
7006 VectorType vectorType = getSourceVectorType();
7007 VectorType resultType = getResultVectorType();
7008 int64_t rank = resultType.getRank();
7009 if (vectorType.getRank() != rank)
7010 return emitOpError(
"vector result rank mismatch: ") << rank;
7012 ArrayRef<int64_t> perm = getPermutation();
7013 int64_t size = perm.size();
7015 return emitOpError(
"transposition length mismatch: ") << size;
7016 SmallVector<bool, 8> seen(rank,
false);
7017 for (
const auto &ta : llvm::enumerate(perm)) {
7018 if (ta.value() < 0 || ta.value() >= rank)
7019 return emitOpError(
"transposition index out of range: ") << ta.value();
7020 if (seen[ta.value()])
7021 return emitOpError(
"duplicate position index: ") << ta.value();
7022 seen[ta.value()] =
true;
7023 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7024 return emitOpError(
"dimension size mismatch at: ") << ta.value();
7029std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7030 return llvm::to_vector<4>(getResultVectorType().
getShape());
7033void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7035 setResultRanges(getResult(), argRanges.front());
7041class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
7045 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7046 PatternRewriter &rewriter)
const override {
7048 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7049 ArrayRef<int64_t> permutation2) {
7050 SmallVector<int64_t, 4>
result;
7051 for (
auto index : permutation2)
7052 result.push_back(permutation1[index]);
7057 vector::TransposeOp parentTransposeOp =
7058 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7059 if (!parentTransposeOp)
7062 SmallVector<int64_t, 4> permutation = composePermutations(
7063 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7066 transposeOp, transposeOp.getResult().
getType(),
7067 parentTransposeOp.getVector(), permutation);
7073class FoldTransposeSplat final :
public OpRewritePattern<TransposeOp> {
7077 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7078 PatternRewriter &rewriter)
const override {
7079 Value splat = getScalarSplatSource(transposeOp.getVector());
7084 transposeOp, transposeOp.getResultVectorType(), splat);
7090class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
7094 LogicalResult matchAndRewrite(TransposeOp transpOp,
7095 PatternRewriter &rewriter)
const override {
7096 Value transposeSrc = transpOp.getVector();
7097 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
7098 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
7099 if (!createMaskOp && !constantMaskOp)
7104 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7107 auto maskOperands = createMaskOp.getOperands();
7108 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7112 transpOp, transpOp.getResultVectorType(), newOperands);
7117 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7121 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7127class FoldTransposeShapeCast final :
public OpRewritePattern<TransposeOp> {
7131 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7132 PatternRewriter &rewriter)
const override {
7134 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7137 if (!isOrderPreserving(transposeOp))
7140 VectorType resultType = transposeOp.getType();
7147 shapeCastOp.getSource());
7166class FoldTransposeFromElements final :
public OpRewritePattern<TransposeOp> {
7169 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7170 PatternRewriter &rewriter)
const override {
7171 auto fromElementsOp =
7172 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7173 if (!fromElementsOp)
7176 VectorType srcTy = fromElementsOp.getDest().getType();
7177 VectorType dstTy = transposeOp.getType();
7179 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7180 int64_t rank = srcTy.getRank();
7183 SmallVector<int64_t> inversePerm(rank, 0);
7184 for (int64_t i = 0; i < rank; ++i)
7185 inversePerm[permutation[i]] = i;
7187 ArrayRef<int64_t> srcShape = srcTy.getShape();
7188 ArrayRef<int64_t> dstShape = dstTy.getShape();
7189 SmallVector<int64_t> srcIdx(rank, 0);
7190 SmallVector<int64_t> dstIdx(rank, 0);
7194 auto elementsOld = fromElementsOp.getElements();
7195 SmallVector<Value> elementsNew;
7196 int64_t dstNumElements = dstTy.getNumElements();
7197 elementsNew.reserve(dstNumElements);
7201 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7205 for (int64_t j = 0; j < rank; ++j)
7206 srcIdx[j] = dstIdx[inversePerm[j]];
7208 int64_t srcLin =
linearize(srcIdx, srcStrides);
7210 elementsNew.push_back(elementsOld[srcLin]);
7244class FoldTransposeBroadcast :
public OpRewritePattern<vector::TransposeOp> {
7247 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7248 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7250 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7251 PatternRewriter &rewriter)
const override {
7257 "not preceded by a broadcast");
7260 auto inputType = dyn_cast<VectorType>(
broadcast.getSourceType());
7261 VectorType outputType = transpose.getResultVectorType();
7264 bool inputIsScalar = !inputType;
7265 if (inputIsScalar) {
7271 ArrayRef<int64_t> permutation = transpose.getPermutation();
7272 ArrayRef<int64_t> inputShape = inputType.getShape();
7273 int64_t inputRank = inputType.getRank();
7274 int64_t outputRank = transpose.getType().getRank();
7275 int64_t deltaRank = outputRank - inputRank;
7278 for (
int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7279 bool notOne = inputShape[inputIndex] != 1;
7280 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7281 bool groupEndFound = notOne || prevNotOne;
7282 if (groupEndFound) {
7283 int high = inputIndex + deltaRank;
7287 for (
int i = low; i < high; ++i) {
7288 if (permutation[i] < low || permutation[i] >= high) {
7290 transpose,
"permutation not local to group");
7304 vector::BroadcastableToResult::Success &&
7305 "not broadcastable directly to transpose output");
7316void vector::TransposeOp::getCanonicalizationPatterns(
7317 RewritePatternSet &results, MLIRContext *context) {
7318 results.
add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7319 FoldTransposeSplat, FoldTransposeFromElements,
7320 FoldTransposeBroadcast>(context);
7327void ConstantMaskOp::build(OpBuilder &builder, OperationState &
result,
7329 assert(kind == ConstantMaskKind::AllTrue ||
7330 kind == ConstantMaskKind::AllFalse);
7331 build(builder,
result, type,
7332 kind == ConstantMaskKind::AllTrue
7334 : SmallVector<int64_t>(type.getRank(), 0));
7337LogicalResult ConstantMaskOp::verify() {
7338 auto resultType = llvm::cast<VectorType>(getResult().
getType());
7340 if (resultType.getRank() == 0) {
7341 if (getMaskDimSizes().size() != 1)
7342 return emitError(
"array attr must have length 1 for 0-D vectors");
7343 auto dim = getMaskDimSizes()[0];
7344 if (dim != 0 && dim != 1)
7345 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
7350 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
7352 "must specify array attr of size equal vector result rank");
7355 auto resultShape = resultType.getShape();
7356 auto resultScalableDims = resultType.getScalableDims();
7357 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7358 for (
const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7359 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7361 "array attr of size out of bounds of vector result dimension size");
7362 if (resultScalableDims[index] && maskDimSize != 0 &&
7363 maskDimSize != resultShape[index])
7365 "only supports 'none set' or 'all set' scalable dimensions");
7369 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7370 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
7371 if (anyZeros && !allZeros)
7372 return emitOpError(
"expected all mask dim sizes to be zeros, "
7373 "as a result of conjunction with zero mask dim");
7377bool ConstantMaskOp::isAllOnesMask() {
7380 if (resultType.getRank() == 0) {
7381 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
7382 return getMaskDimSizes()[0] == 1;
7384 for (
const auto [resultSize, maskDimSize] :
7385 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7386 if (maskDimSize < resultSize)
7392OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7393 ArrayRef<int64_t> bounds = getMaskDimSizes();
7396 auto createBoolSplat = [&](
bool x) {
7402 if (vectorSizes.empty()) {
7403 assert(bounds.size() == 1 &&
"invalid sizes for zero rank mask");
7404 return createBoolSplat(bounds[0] == 1);
7407 if (bounds == vectorSizes)
7408 return createBoolSplat(
true);
7409 if (llvm::all_of(bounds, [](int64_t x) {
return x == 0; }))
7410 return createBoolSplat(
false);
7411 return OpFoldResult();
7418void CreateMaskOp::build(OpBuilder &builder, OperationState &
result,
7420 ArrayRef<OpFoldResult> mixedOperands) {
7421 SmallVector<Value> operands =
7423 build(builder,
result, type, operands);
7426LogicalResult CreateMaskOp::verify() {
7427 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
7429 if (vectorType.getRank() == 0) {
7430 if (getNumOperands() != 1)
7432 "must specify exactly one operand for 0-D create_mask");
7433 }
else if (getNumOperands() !=
7434 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
7436 "must specify an operand for each result vector dimension");
7466class CreateMaskFolder final :
public OpRewritePattern<CreateMaskOp> {
7470 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7471 PatternRewriter &rewriter)
const override {
7472 VectorType maskType = createMaskOp.getVectorType();
7473 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7474 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7477 constexpr std::array<int64_t, 1> rankZeroShape{1};
7478 constexpr std::array<bool, 1> rankZeroScalableDims{
false};
7479 if (maskType.getRank() == 0) {
7480 maskTypeDimSizes = rankZeroShape;
7481 maskTypeDimScalableFlags = rankZeroScalableDims;
7486 SmallVector<int64_t, 4> constantDims;
7487 for (
auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7492 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7494 constantDims.push_back(*intSize);
7498 if (vscaleMultiplier < maskTypeDimSizes[i])
7500 constantDims.push_back(*vscaleMultiplier);
7507 for (
auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7508 value = std::clamp<int64_t>(value, 0, maskDimSize);
7511 if (llvm::is_contained(constantDims, 0))
7512 constantDims.assign(constantDims.size(), 0);
7523void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7524 MLIRContext *context) {
7525 results.
add<CreateMaskFolder>(context);
7533 OpBuilder &builder, OperationState &
result, Value mask,
7534 Operation *maskableOp,
7535 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7536 assert(maskRegionBuilder &&
7537 "builder callback for 'maskRegion' must be present");
7539 result.addOperands(mask);
7540 OpBuilder::InsertionGuard guard(builder);
7541 Region *maskRegion =
result.addRegion();
7543 maskRegionBuilder(builder, maskableOp);
7548 Value mask, Operation *maskableOp,
7549 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7550 build(builder,
result, resultTypes, mask, Value(), maskableOp,
7556 Value mask, Value passthru, Operation *maskableOp,
7557 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7558 build(builder,
result, mask, maskableOp, maskRegionBuilder);
7560 result.addOperands(passthru);
7561 result.addTypes(resultTypes);
7564ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &
result) {
7566 result.regions.reserve(1);
7567 Region &maskRegion = *
result.addRegion();
7572 OpAsmParser::UnresolvedOperand mask;
7577 OpAsmParser::UnresolvedOperand passthru;
7579 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
7586 MaskOp::ensureTerminator(maskRegion, builder,
result.location);
7597 SmallVector<Type> resultTypes;
7600 result.types.append(resultTypes);
7606 if (parsePassthru.succeeded()) {
7607 if (resultTypes.empty())
7610 "expects a result if passthru operand is provided");
7619void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7620 p <<
" " << getMask();
7622 p <<
", " << getPassthru();
7626 Block *singleBlock = &getMaskRegion().getBlocks().front();
7633 p <<
" : " << getMask().getType();
7634 if (getNumResults() > 0)
7635 p <<
" -> " << getResultTypes();
7638void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
7641 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7642 MaskOp>::ensureTerminator(region, builder, loc);
7648 if (isa<vector::YieldOp>(block.
back()))
7656 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7657 MaskOp>::ensureTerminator(region, builder, loc);
7663 Operation *maskedOp = &block.
front();
7664 opBuilder.setInsertionPointToEnd(&block);
7665 vector::YieldOp::create(opBuilder, loc, maskedOp->
getResults());
7668LogicalResult MaskOp::verify() {
7670 Block &block = getMaskRegion().getBlocks().
front();
7672 return emitOpError(
"expects a terminator within the mask region");
7675 if (numMaskRegionOps > 2)
7676 return emitOpError(
"expects only one operation to mask");
7679 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
7681 return emitOpError(
"expects a terminator within the mask region");
7683 if (terminator->getNumOperands() != getNumResults())
7685 "expects number of results to match mask region yielded values");
7688 if (numMaskRegionOps == 1)
7691 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
7693 return emitOpError(
"expects a MaskableOpInterface within the mask region");
7697 return emitOpError(
"expects number of results to match maskable operation "
7698 "number of results");
7700 if (!llvm::equal(maskableOp->
getResults(), terminator.getOperands()))
7701 return emitOpError(
"expects all the results from the MaskableOpInterface "
7702 "to match all the values returned by the terminator");
7704 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
7706 "expects result type to match maskable operation result type");
7709 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7710 return emitOpError(
"multiple vector results not supported");
7713 Type expectedMaskType = maskableOp.getExpectedMaskType();
7714 if (getMask().
getType() != expectedMaskType)
7716 << expectedMaskType <<
" mask for the maskable operation";
7719 Value passthru = getPassthru();
7721 if (!maskableOp.supportsPassthru())
7723 "doesn't expect a passthru argument for this maskable operation");
7726 return emitOpError(
"expects result when passthru argument is provided");
7729 return emitOpError(
"expects passthru type to match result type");
7749static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7750 SmallVectorImpl<OpFoldResult> &results) {
7751 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7754 Block *block = maskOp.getMaskBlock();
7755 auto terminator = cast<vector::YieldOp>(block->
front());
7756 if (terminator.getNumOperands() == 0)
7760 llvm::append_range(results, terminator.getOperands());
7764LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7765 SmallVectorImpl<OpFoldResult> &results) {
7766 if (succeeded(foldEmptyMaskOp(*
this, adaptor, results)))
7776 Operation *maskableOp = getMaskableOp();
7782 llvm::append_range(results, maskableOp->
getResults());
7798class CanonializeEmptyMaskOp :
public OpRewritePattern<MaskOp> {
7801 LogicalResult matchAndRewrite(MaskOp maskOp,
7802 PatternRewriter &rewriter)
const override {
7803 if (!maskOp.isEmpty())
7806 if (!maskOp.hasPassthru())
7813 VectorType maskType = maskOp.getMask().getType();
7814 for (Type resultType : maskOp.getResultTypes()) {
7815 auto vecResultType = dyn_cast<VectorType>(resultType);
7816 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
7820 Block *block = maskOp.getMaskBlock();
7821 auto terminator = cast<vector::YieldOp>(block->
front());
7822 assert(terminator.getNumOperands() == 1 &&
7823 "expected one result when passthru is provided");
7826 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7827 terminator.getOperand(0), maskOp.getPassthru());
7833void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7834 MLIRContext *context) {
7835 results.
add<CanonializeEmptyMaskOp>(context);
7841Operation *MaskOp::getMaskableOp() {
7842 Block *block = getMaskBlock();
7846 return &block->
front();
7850bool MaskOp::hasPassthru() {
return getPassthru() != Value(); }
7856LogicalResult ScanOp::verify() {
7857 VectorType srcType = getSourceType();
7858 VectorType initialType = getInitialValueType();
7860 int64_t srcRank = srcType.getRank();
7861 int64_t reductionDim = getReductionDim();
7862 if (reductionDim >= srcRank)
7864 << reductionDim <<
" has to be less than " << srcRank;
7867 int64_t initialValueRank = initialType.getRank();
7868 if (initialValueRank != srcRank - 1)
7870 << initialValueRank <<
" has to be equal to " << srcRank - 1;
7873 ArrayRef<int64_t> srcShape = srcType.getShape();
7874 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7875 SmallVector<int64_t> expectedShape;
7876 for (
int i = 0; i < srcRank; i++) {
7877 if (i != reductionDim)
7878 expectedShape.push_back(srcShape[i]);
7880 if (!llvm::equal(initialValueShapes, expectedShape)) {
7881 return emitOpError(
"incompatible input/initial value shapes");
7885 Type eltType = getDestType().getElementType();
7888 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
7895 RewritePatternSet &patterns, PatternBenefit benefit) {
7897 .
add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7898 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7899 StridedSliceConstantMaskFolder, TransposeFolder>(
7904 CombiningKind kind, Value v1, Value acc,
7905 arith::FastMathFlagsAttr fastmath,
7912 case CombiningKind::ADD:
7914 result =
b.createOrFold<arith::AddIOp>(loc, v1, acc);
7915 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7916 result =
b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7918 llvm_unreachable(
"invalid value types for ADD reduction");
7920 case CombiningKind::AND:
7922 result =
b.createOrFold<arith::AndIOp>(loc, v1, acc);
7924 case CombiningKind::MAXNUMF:
7925 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7926 "expected float values");
7927 result =
b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7929 case CombiningKind::MAXIMUMF:
7930 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7931 "expected float values");
7932 result =
b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7934 case CombiningKind::MINNUMF:
7935 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7936 "expected float values");
7937 result =
b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7939 case CombiningKind::MINIMUMF:
7940 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7941 "expected float values");
7942 result =
b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7944 case CombiningKind::MAXSI:
7946 result =
b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7948 case CombiningKind::MINSI:
7950 result =
b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7952 case CombiningKind::MAXUI:
7954 result =
b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7956 case CombiningKind::MINUI:
7958 result =
b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7960 case CombiningKind::MUL:
7962 result =
b.createOrFold<arith::MulIOp>(loc, v1, acc);
7963 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7964 result =
b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7966 llvm_unreachable(
"invalid value types for MUL reduction");
7968 case CombiningKind::OR:
7970 result =
b.createOrFold<arith::OrIOp>(loc, v1, acc);
7972 case CombiningKind::XOR:
7974 result =
b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7978 assert(
result &&
"unknown CombiningKind");
7986void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7988 auto resultType = cast<VectorType>(
getType());
7989 if (resultType.isScalable()) {
7993 APInt zero(bitwidth, 0);
7994 APInt high(bitwidth, resultType.getDimSize(0) - 1);
7995 ConstantIntRanges
result = {zero, high, zero, high};
7996 setResultRanges(getResult(),
result);
8026struct StepCompareFolder :
public OpRewritePattern<StepOp> {
8029 LogicalResult matchAndRewrite(StepOp stepOp,
8030 PatternRewriter &rewriter)
const override {
8031 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
8033 for (OpOperand &use : stepOp.getResult().getUses()) {
8034 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8039 const unsigned stepOperandNumber = use.getOperandNumber();
8040 if (stepOperandNumber != 0)
8044 unsigned constOperandNumber = 1;
8045 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8046 std::optional<int64_t> maybeConstValue =
8048 if (!maybeConstValue.has_value())
8051 int64_t constValue = maybeConstValue.value();
8052 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8054 auto maybeSplat = [&]() -> std::optional<bool> {
8056 if ((pred == arith::CmpIPredicate::ult ||
8057 pred == arith::CmpIPredicate::uge) &&
8058 stepSize <= constValue)
8059 return pred == arith::CmpIPredicate::ult;
8062 if ((pred == arith::CmpIPredicate::ule ||
8063 pred == arith::CmpIPredicate::ugt) &&
8064 stepSize - 1 <= constValue) {
8065 return pred == arith::CmpIPredicate::ule;
8069 if ((pred == arith::CmpIPredicate::eq ||
8070 pred == arith::CmpIPredicate::ne) &&
8071 stepSize <= constValue)
8072 return pred == arith::CmpIPredicate::ne;
8074 return std::nullopt;
8077 if (!maybeSplat.has_value())
8082 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8087 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8099void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8100 MLIRContext *context) {
8101 results.
add<StepCompareFolder>(context);
8111 Operation *maskableOp) {
8112 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
8124 Operation *maskableOp, Value mask,
8129 return MaskOp::create(builder, maskableOp->
getLoc(),
8132 return MaskOp::create(builder, maskableOp->
getLoc(),
8145 Value newValue, Value passthru) {
8149 return arith::SelectOp::create(builder, newValue.
getLoc(), newValue.
getType(),
8150 mask, newValue, passthru);
8157#define GET_ATTRDEF_CLASSES
8158#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8160#define GET_OP_CLASSES
8161#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into a poison value when a poison index is found at any dimension of the static position.
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)