44#include "llvm/ADT/ArrayRef.h"
45#include "llvm/ADT/Repeated.h"
46#include "llvm/ADT/STLExtras.h"
47#include "llvm/ADT/SmallVector.h"
48#include "llvm/ADT/SmallVectorExtras.h"
49#include "llvm/ADT/StringSet.h"
50#include "llvm/ADT/TypeSwitch.h"
51#include "llvm/Support/Casting.h"
57#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
59#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
80 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
82 for (
bool b : denseElts.getValues<
bool>())
85 else if (!
b && val <= 0)
99 auto shape = m.getType().getShape();
101 bool allFalse =
true;
102 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks,
shape)) {
103 if (maskIdx < dimSize)
116 auto maskOperands = m.getOperands();
117 for (
Value operand : maskOperands) {
118 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
120 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
133 vector::YieldOp::create(builder, loc);
139 switch (combiningKind) {
140 case CombiningKind::ADD:
141 case CombiningKind::MUL:
143 case CombiningKind::MINUI:
144 case CombiningKind::MINSI:
145 case CombiningKind::MAXUI:
146 case CombiningKind::MAXSI:
147 case CombiningKind::AND:
148 case CombiningKind::OR:
149 case CombiningKind::XOR:
151 case CombiningKind::MINNUMF:
152 case CombiningKind::MAXNUMF:
153 case CombiningKind::MINIMUMF:
154 case CombiningKind::MAXIMUMF:
155 return llvm::isa<FloatType>(elementType);
185 VectorType vectorType) {
186 unsigned elementVectorRank = 0;
187 VectorType elementVectorType =
188 llvm::dyn_cast<VectorType>(shapedType.getElementType());
189 if (elementVectorType)
190 elementVectorRank += elementVectorType.getRank();
191 return vectorType.getRank() - elementVectorRank;
195 VectorType vectorType) {
198 if (shapedType.getRank() == 0 &&
204 shapedType.getRank(),
206 shapedType.getContext());
213 vector::TransferReadOp read) {
214 auto readMask = read.getMask();
215 auto writeMask = write.getMask();
221 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
222 if (!couldBeSameSplat)
239 vector::TransferReadOp read) {
240 return !defWrite.hasOutOfBoundsDim() &&
241 defWrite.getIndices() == read.getIndices() &&
242 defWrite.getVectorType() == read.getVectorType() &&
243 defWrite.getPermutationMap() == read.getPermutationMap() &&
244 ((!defWrite.getMask() && !read.getMask()) ||
249 vector::TransferWriteOp priorWrite) {
250 return priorWrite.getIndices() == write.getIndices() &&
251 priorWrite.getMask() == write.getMask() &&
252 priorWrite.getVectorType() == write.getVectorType() &&
253 priorWrite.getPermutationMap() == write.getPermutationMap();
257 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
258 bool testDynamicValueUsingBounds) {
260 if (transferA.getVectorType() != transferB.getVectorType())
262 unsigned rankOffset = transferA.getLeadingShapedRank();
263 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
264 Value indexA = transferA.getIndices()[i];
265 Value indexB = transferB.getIndices()[i];
269 if (i < rankOffset) {
272 if (cstIndexA.has_value() && cstIndexB.has_value()) {
273 if (*cstIndexA != *cstIndexB)
277 if (testDynamicValueUsingBounds) {
280 FailureOr<uint64_t> delta =
282 if (succeeded(delta) && *delta != 0)
285 FailureOr<bool> testEqual =
287 if (succeeded(testEqual) && !testEqual.value())
293 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
294 if (cstIndexA.has_value() && cstIndexB.has_value()) {
295 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
296 if (distance >= vectorDim)
300 if (testDynamicValueUsingBounds) {
303 FailureOr<int64_t> delta =
305 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
308 FailureOr<int64_t> computeDelta =
310 if (succeeded(computeDelta)) {
311 if (std::abs(computeDelta.value()) >= vectorDim)
321 VectorTransferOpInterface transferB,
322 bool testDynamicValueUsingBounds) {
323 if (transferA.getBase() != transferB.getBase())
326 testDynamicValueUsingBounds);
336 for (
auto [posInDim, dimSize, offsetInDim] :
337 llvm::reverse(llvm::zip_equal(position,
shape, offsets))) {
339 if (posInDim < dimSize + offsetInDim)
343 posInDim = offsetInDim;
353 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
355 assert(constOp &&
"Unexpected non-constant index");
356 return constOp.value();
366 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
367 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
368 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
378 llvm::transform(foldResults, std::back_inserter(values),
380 if (
auto attr = dyn_cast<Attribute>(foldResult))
382 builder, loc, cast<IntegerAttr>(attr).getInt())
385 return cast<Value>(foldResult);
398 if (
lhs.getDefiningOp<vector::VectorScaleOp>())
400 if (
rhs.getDefiningOp<vector::VectorScaleOp>())
410 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
411 if (
auto intType = dyn_cast<IntegerType>(expectedType)) {
412 if (intAttr.getType() != expectedType)
413 return IntegerAttr::get(expectedType, intAttr.getInt());
419 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
420 auto intType = dyn_cast<IntegerType>(expectedType);
424 APFloat floatVal = floatAttr.getValue();
425 APInt intVal = floatVal.bitcastToAPInt();
426 return IntegerAttr::get(expectedType, intVal);
465struct VectorInlinerInterface :
public DialectInlinerInterface {
466 using DialectInlinerInterface::DialectInlinerInterface;
475void VectorDialect::initialize() {
477#define GET_ATTRDEF_LIST
478#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
483#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
486 addInterfaces<VectorInlinerInterface>();
488 declarePromisedInterfaces<memref::IndexedAccessOpInterface, LoadOp, StoreOp,
489 MaskedLoadOp, MaskedStoreOp, ExpandLoadOp,
491 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
492 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
494 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
496 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
497 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
498 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
509 return arith::ConstantOp::materialize(builder, value, type, loc);
525void vector::MultiDimReductionOp::build(
OpBuilder &builder,
528 CombiningKind kind) {
530 for (
const auto &en : llvm::enumerate(reductionMask))
532 reductionDims.push_back(en.index());
533 build(builder,
result, kind, source,
acc, reductionDims);
536OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
538 if (getReductionDims().empty())
543std::optional<SmallVector<int64_t, 4>>
544MultiDimReductionOp::getShapeForUnroll() {
545 return llvm::to_vector<4>(getSourceVectorType().
getShape());
548LogicalResult MultiDimReductionOp::verify() {
551 Type inferredReturnType;
552 auto sourceScalableDims = getSourceVectorType().getScalableDims();
553 for (
auto [dimIdx, dimSize] :
554 llvm::enumerate(getSourceVectorType().
getShape()))
555 if (!llvm::any_of(getReductionDims(),
556 [dimIdx = dimIdx](
int64_t reductionDimIdx) {
557 return reductionDimIdx ==
static_cast<int64_t>(dimIdx);
559 targetShape.push_back(dimSize);
560 scalableDims.push_back(sourceScalableDims[dimIdx]);
563 if (targetShape.empty())
564 inferredReturnType = getSourceVectorType().getElementType();
566 inferredReturnType = VectorType::get(
567 targetShape, getSourceVectorType().
getElementType(), scalableDims);
568 if (
getType() != inferredReturnType)
570 <<
" is incompatible with source type "
571 << getSourceVectorType();
577Type MultiDimReductionOp::getExpectedMaskType() {
578 auto vecType = getSourceVectorType();
579 return VectorType::get(vecType.getShape(),
580 IntegerType::get(vecType.getContext(), 1),
581 vecType.getScalableDims());
590struct ElideUnitDimsInMultiDimReduction
594 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
595 PatternRewriter &rewriter)
const override {
596 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
597 for (
const auto &dim :
enumerate(shape)) {
598 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
603 OpBuilder::InsertionGuard guard(rewriter);
606 if (reductionOp.isMasked()) {
608 rootOp = reductionOp.getMaskingOp();
609 mask = reductionOp.getMaskingOp().getMask();
611 rootOp = reductionOp;
614 Location loc = reductionOp.getLoc();
615 Value acc = reductionOp.getAcc();
617 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
619 VectorType newMaskType =
620 VectorType::get(dstVecType.getShape(), rewriter.
getI1Type(),
621 dstVecType.getScalableDims());
622 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
624 cast = vector::ShapeCastOp::create(
625 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
630 mask = vector::ExtractOp::create(rewriter, loc, mask);
631 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
636 cast,
nullptr, mask);
643void MultiDimReductionOp::getCanonicalizationPatterns(
645 results.
add<ElideUnitDimsInMultiDimReduction>(context);
654 arith::FastMathFlags fastMathFlags) {
660 arith::FastMathFlags fastMathFlags) {
662 llvm::cast<VectorType>(
vector.getType()).getElementType(), kind,
vector,
666LogicalResult ReductionOp::verify() {
668 int64_t rank = getSourceVectorType().getRank();
670 return emitOpError(
"unsupported reduction rank: ") << rank;
673 Type eltType = getDest().getType();
676 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
685Type ReductionOp::getExpectedMaskType() {
686 auto vecType = getSourceVectorType();
687 return VectorType::get(vecType.getShape(),
688 IntegerType::get(vecType.getContext(), 1),
689 vecType.getScalableDims());
696 case arith::AtomicRMWKind::addf:
697 case arith::AtomicRMWKind::addi:
698 return vector::ReductionOp::create(builder,
vector.getLoc(),
699 CombiningKind::ADD,
vector);
700 case arith::AtomicRMWKind::mulf:
701 case arith::AtomicRMWKind::muli:
702 return vector::ReductionOp::create(builder,
vector.getLoc(),
703 CombiningKind::MUL,
vector);
704 case arith::AtomicRMWKind::minimumf:
705 return vector::ReductionOp::create(builder,
vector.getLoc(),
706 CombiningKind::MINIMUMF,
vector);
707 case arith::AtomicRMWKind::mins:
708 return vector::ReductionOp::create(builder,
vector.getLoc(),
709 CombiningKind::MINSI,
vector);
710 case arith::AtomicRMWKind::minu:
711 return vector::ReductionOp::create(builder,
vector.getLoc(),
712 CombiningKind::MINUI,
vector);
713 case arith::AtomicRMWKind::maximumf:
714 return vector::ReductionOp::create(builder,
vector.getLoc(),
715 CombiningKind::MAXIMUMF,
vector);
716 case arith::AtomicRMWKind::maxs:
717 return vector::ReductionOp::create(builder,
vector.getLoc(),
718 CombiningKind::MAXSI,
vector);
719 case arith::AtomicRMWKind::maxu:
720 return vector::ReductionOp::create(builder,
vector.getLoc(),
721 CombiningKind::MAXUI,
vector);
722 case arith::AtomicRMWKind::andi:
723 return vector::ReductionOp::create(builder,
vector.getLoc(),
724 CombiningKind::AND,
vector);
725 case arith::AtomicRMWKind::ori:
726 return vector::ReductionOp::create(builder,
vector.getLoc(),
727 CombiningKind::OR,
vector);
728 case arith::AtomicRMWKind::minnumf:
729 return vector::ReductionOp::create(builder,
vector.getLoc(),
730 CombiningKind::MINNUMF,
vector);
731 case arith::AtomicRMWKind::maxnumf:
732 return vector::ReductionOp::create(builder,
vector.getLoc(),
733 CombiningKind::MAXNUMF,
vector);
734 case arith::AtomicRMWKind::xori:
735 return vector::ReductionOp::create(builder,
vector.getLoc(),
736 CombiningKind::XOR,
vector);
744std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
745 return llvm::to_vector<4>(getSourceVectorType().
getShape());
752 LogicalResult matchAndRewrite(ReductionOp reductionOp,
757 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
760 if (maskableOp.isMasked()) {
762 rootOp = maskableOp.getMaskingOp();
763 mask = maskableOp.getMaskingOp().getMask();
765 rootOp = reductionOp;
768 auto vectorType = reductionOp.getSourceVectorType();
769 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
772 Location loc = reductionOp.getLoc();
774 mask = ExtractOp::create(rewriter, loc, mask);
775 Value
result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
777 if (Value acc = reductionOp.getAcc())
780 reductionOp.getFastmathAttr(), mask);
790 results.
add<ElideSingleElementReduction>(context);
804 getIndexingMapsAttrName(
result.name),
808 getIteratorTypesAttrName(
result.name),
811 return IteratorTypeAttr::get(builder.getContext(), t);
820 ContractionOp::getDefaultKind());
826 ArrayAttr iteratorTypes, CombiningKind kind,
827 arith::FastMathFlags fastMathFlags) {
830 result.addAttribute(getIndexingMapsAttrName(
result.name), indexingMaps);
831 result.addAttribute(getIteratorTypesAttrName(
result.name), iteratorTypes);
833 CombiningKindAttr::get(builder.
getContext(), kind));
834 if (fastMathFlags != arith::FastMathFlags::none)
836 getFastmathAttrName(
result.name),
837 arith::FastMathFlagsAttr::get(builder.
getContext(), fastMathFlags));
848 DictionaryAttr dictAttr;
862 result.attributes.append(dictAttr.getValue().begin(),
863 dictAttr.getValue().end());
869 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
870 result.attributes.get(getIteratorTypesAttrName(
result.name)));
871 if (!iteratorTypes) {
873 <<
"expected " << getIteratorTypesAttrName(
result.name)
874 <<
" array attribute";
879 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
880 auto maybeIteratorType = symbolizeIteratorType(s);
881 if (!maybeIteratorType.has_value())
882 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
884 iteratorTypeAttrs.push_back(
885 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
887 result.attributes.set(getIteratorTypesAttrName(
result.name),
890 if (!
result.attributes.get(getKindAttrName(
result.name))) {
892 getKindAttrName(
result.name),
893 CombiningKindAttr::get(
result.getContext(),
894 ContractionOp::getDefaultKind()));
896 if (masksInfo.empty())
898 if (masksInfo.size() != 2)
900 "expected zero or exactly 2 vector mask operands");
901 auto lhsType = llvm::cast<VectorType>(types[0]);
902 auto rhsType = llvm::cast<VectorType>(types[1]);
904 std::array<VectorType, 2> maskTypes = {
914 auto attrNames = getTraitAttrNames();
916 traitAttrsSet.insert_range(attrNames);
918 for (
auto attr : (*this)->getAttrs()) {
919 if (attr.getName() == getIteratorTypesAttrName()) {
921 llvm::cast<ArrayAttr>(attr.getValue())
922 .getAsValueRange<IteratorTypeAttr, IteratorType>();
928 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) ->
Attribute {
929 return StringAttr::get(
getContext(), stringifyIteratorType(t));
932 attrs.emplace_back(getIteratorTypesAttrName(),
933 ArrayAttr::get(
getContext(), iteratorTypeNames));
934 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
936 if (attr.getName() == getFastmathAttrName() &&
937 llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
938 arith::FastMathFlags::none)
940 attrs.push_back(attr);
944 auto dictAttr = DictionaryAttr::get(
getContext(), attrs);
945 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
946 p << getRhs() <<
", " << getAcc();
949 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
954 const std::vector<std::pair<int64_t, int64_t>> &map) {
955 for (
auto &dimPair : map) {
956 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
957 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
958 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
965 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
967 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
968 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
971 for (
auto &dimPair : contractingDimMap) {
972 lhsContractingDimSet.insert(dimPair.first);
973 rhsContractingDimSet.insert(dimPair.second);
976 llvm::make_second_range(batchDimMap));
980 for (
int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
981 if (lhsContractingDimSet.count(i) > 0)
983 expectedResultDims.push_back(lhsType.getDimSize(i));
987 for (
int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
988 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
990 expectedResultDims.push_back(rhsType.getDimSize(i));
994 if (expectedResultDims.empty()) {
996 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
997 return op.emitOpError(
"invalid accumulator/result vector shape");
1000 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
1001 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
1002 if (!resVectorType || !accVectorType)
1003 return op.emitOpError(
"invalid accumulator/result vector shape");
1009 AffineMap lhsMap = op.getIndexingMapsArray()[0];
1010 AffineMap rhsMap = op.getIndexingMapsArray()[1];
1012 return op.emitOpError(
1013 "expected all dimensions to be either a LHS or a RHS dimension");
1016 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1017 VectorType v = pair.first;
1018 auto map = pair.second;
1019 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1020 unsigned pos = map.getDimPosition(idx);
1025 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
1026 return op.emitOpError(
"expected all dimensions to get an extent as "
1027 "either a LHS or a RHS dimension");
1029 AffineMap resMap = op.getIndexingMapsArray()[2];
1034 assert(llvm::all_of(expectedMap.
getResults(),
1035 llvm::IsaPred<AffineConstantExpr>) &&
1036 "expected constant extent along all dimensions.");
1038 auto expectedShape =
1040 return cast<AffineConstantExpr>(e).getValue();
1043 VectorType::get(expectedShape, resVectorType.getElementType(),
1044 resVectorType.getScalableDims());
1045 if (resVectorType != expected || accVectorType != expected)
1046 return op.emitOpError(
1047 "invalid accumulator/result vector shape, expected: ")
1053LogicalResult ContractionOp::verify() {
1054 VectorType lhsType = getLhsType();
1055 VectorType rhsType = getRhsType();
1056 Type accType = getAccType();
1057 Type resType = getResultType();
1059 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1060 if (!lhsType.getElementType().isSignlessInteger())
1061 return emitOpError(
"only supports signless integer types");
1065 if (getIndexingMapsArray().size() != 3)
1066 return emitOpError(
"expected an indexing map for each vector operand");
1071 unsigned numIterators = getIteratorTypes().getValue().size();
1072 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1073 auto index = it.index();
1074 auto map = it.value();
1075 if (map.getNumSymbols() != 0)
1077 <<
index <<
" to have no symbols";
1078 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(
index).
getType());
1079 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1082 if (map.getNumDims() != numIterators)
1084 <<
index <<
" to have " << numIterators <<
" number of inputs";
1085 if (map.getNumResults() != rank)
1087 <<
index <<
" to have " << rank <<
" number of outputs";
1088 if (!map.isProjectedPermutation())
1090 <<
index <<
" to be a projected permutation of its inputs";
1093 auto contractingDimMap = getContractingDimMap();
1094 auto batchDimMap = getBatchDimMap();
1097 if (contractingDimMap.empty())
1098 return emitOpError(
"expected at least one contracting dimension pair");
1101 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1102 return emitOpError(
"invalid contracting dimension map");
1106 return emitOpError(
"invalid batch dimension map");
1110 contractingDimMap, batchDimMap)))
1113 if (!getKindAttr()) {
1114 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
1115 "'vector.kind<add>')");
1119 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1120 auto elementType = vectorType ? vectorType.getElementType() : resType;
1122 return emitOpError(
"unsupported contraction type");
1125 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1132Type ContractionOp::getExpectedMaskType() {
1133 auto indexingMaps = this->getIndexingMapsArray();
1136 VectorType lhsType = this->getLhsType();
1137 VectorType rhsType = this->getRhsType();
1139 unsigned numVecDims = lhsIdxMap.
getNumDims();
1145 for (
auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1148 lhsType.getScalableDims()[dimIdx];
1150 for (
auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1153 rhsType.getScalableDims()[dimIdx];
1156 assert(ShapedType::isStaticShape(maskShape) &&
1157 "Mask shape couldn't be computed");
1159 return VectorType::get(maskShape,
1160 IntegerType::get(lhsType.getContext(), 1),
1161 maskShapeScalableDims);
1166 getIteratorTypesAttrName(), getKindAttrName(),
1167 getFastmathAttrName()};
1177static std::vector<std::pair<int64_t, int64_t>>
1179 IteratorType targetIteratorType,
MLIRContext *context) {
1180 std::vector<std::pair<int64_t, int64_t>> dimMap;
1181 for (
const auto &it : llvm::enumerate(iteratorTypes)) {
1182 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1183 if (iteratorType != targetIteratorType)
1189 if (lhsDim >= 0 && rhsDim >= 0)
1190 dimMap.emplace_back(lhsDim, rhsDim);
1195void ContractionOp::getIterationBounds(
1197 auto lhsShape = getLhsType().getShape();
1198 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1200 for (
const auto &it : llvm::enumerate(getIteratorTypes())) {
1203 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1204 if (iteratorType == IteratorType::reduction) {
1207 assert(lhsDimIndex >= 0);
1208 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1213 assert(resDimIndex >= 0);
1214 assert(resVectorType !=
nullptr);
1215 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1219void ContractionOp::getIterationIndexMap(
1221 unsigned numMaps = getIndexingMapsArray().size();
1222 iterationIndexMap.resize(numMaps);
1223 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1224 auto index = it.index();
1225 auto map = it.value();
1226 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1227 auto dim = cast<AffineDimExpr>(map.getResult(i));
1228 iterationIndexMap[
index][dim.getPosition()] = i;
1233std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1235 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1239std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1241 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1245std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1247 getIterationBounds(
shape);
1269template <
typename AddOpType>
1275 auto canonicalize = [&](
Value maybeContraction,
1276 Value otherOperand) -> vector::ContractionOp {
1277 vector::ContractionOp contractionOp =
1278 dyn_cast_or_null<vector::ContractionOp>(
1281 return vector::ContractionOp();
1282 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1283 contractionOp.getAcc().getDefiningOp())) {
1284 if (maybeZero.getValue() ==
1285 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1287 bvm.
map(contractionOp.getAcc(), otherOperand);
1288 auto newContraction =
1289 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1290 rewriter.
replaceOp(addOp, newContraction.getResult());
1291 return newContraction;
1294 return vector::ContractionOp();
1297 Value a = addOp->getOperand(0),
b = addOp->getOperand(1);
1298 vector::ContractionOp
contract = canonicalize(a,
b);
1323 setResultRanges(getResult(), argRanges.front());
1328 auto vectorTy = cast<VectorType>(source.
getType());
1353 build(builder,
result, source, dynamicPos,
1358ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1359 ExtractOp::Adaptor adaptor,
1361 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1362 if (
static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1363 vectorType.getRank()) {
1364 inferredReturnTypes.push_back(vectorType.getElementType());
1366 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1367 vectorType.getRank());
1368 inferredReturnTypes.push_back(VectorType::get(
1369 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1370 vectorType.getScalableDims().drop_front(n)));
1375LogicalResult vector::ExtractOp::verify() {
1376 if (
auto resTy = dyn_cast<VectorType>(getResult().
getType()))
1377 if (resTy.getRank() == 0)
1379 "expected a scalar instead of a 0-d vector as the result type");
1382 auto dynamicMarkersCount =
1383 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1384 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1386 "mismatch between dynamic and static positions (kDynamic marker but no "
1387 "corresponding dynamic position) -- this can only happen due to an "
1388 "incorrect fold/rewrite");
1389 auto position = getMixedPosition();
1390 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1392 "expected position attribute of rank no greater than vector rank");
1393 for (
auto [idx, pos] : llvm::enumerate(position)) {
1394 if (
auto attr = dyn_cast<Attribute>(pos)) {
1395 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1397 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1398 return emitOpError(
"expected position attribute #")
1400 <<
" to be a non-negative integer smaller than the "
1401 "corresponding vector dimension or poison (-1)";
1408template <
typename IntType>
1410 return llvm::map_to_vector<4>(
1411 arrayAttr.getAsRange<IntegerAttr>(),
1412 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1418 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1422 if (extractOp.hasDynamicPosition())
1426 ExtractOp currentOp = extractOp;
1428 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1429 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1432 if (currentOp.hasDynamicPosition())
1435 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1437 extractOp.setOperand(0, currentOp.getSource());
1440 std::reverse(globalPosition.begin(), globalPosition.end());
1441 extractOp.setStaticPosition(globalPosition);
1453class ExtractFromInsertTransposeChainState {
1455 ExtractFromInsertTransposeChainState(ExtractOp e);
1464 template <
typename ContainerA,
typename ContainerB>
1465 bool isContainedWithin(
const ContainerA &a,
const ContainerB &
b) {
1466 return a.size() <=
b.size() &&
1467 std::equal(a.begin(), a.begin() + a.size(),
b.begin());
1474 template <
typename ContainerA,
typename ContainerB>
1475 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &
b) {
1476 for (
auto [elemA, elemB] : llvm::zip(a,
b)) {
1477 if (elemA < 0 || elemB < 0)
1488 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1492 void updateStateForNextIteration(Value v) {
1499 LogicalResult handleTransposeOp();
1502 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1517 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1522 Value tryToFoldExtractOpInPlace(Value source);
1524 ExtractOp extractOp;
1526 int64_t extractedRank;
1528 InsertOp nextInsertOp;
1529 TransposeOp nextTransposeOp;
1539 SmallVector<int64_t> sentinels;
1540 SmallVector<int64_t> extractPosition;
1544ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1546 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1547 extractedRank(extractOp.getNumIndices()) {
1548 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1549 sentinels.reserve(vectorRank - extractedRank);
1550 for (
int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1551 sentinels.push_back(-(i + 1));
1553 extractOp.getStaticPosition().end());
1559LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1561 if (extractOp.hasDynamicPosition())
1564 if (!nextTransposeOp)
1567 nextTransposeOp.getPermutation(), extractOp.getContext()));
1574ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1577 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1580 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1581 if (insertedPos != llvm::ArrayRef(
extractPosition).take_front(extractedRank))
1584 res = nextInsertOp.getValueToStore();
1593ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1595 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1598 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1608 res = nextInsertOp.getValueToStore();
1616Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1619 if (extractOp.hasDynamicPosition())
1623 bool nothingToFold = (source == extractOp.getSource());
1624 if (nothingToFold || !canFold())
1628 OpBuilder
b(extractOp.getContext());
1629 extractOp.setStaticPosition(
1631 extractOp.getSourceMutable().assign(source);
1632 return extractOp.getResult();
1636Value ExtractFromInsertTransposeChainState::fold() {
1638 if (extractOp.hasDynamicPosition())
1641 Value valueToExtractFrom = extractOp.getSource();
1642 updateStateForNextIteration(valueToExtractFrom);
1643 while (nextInsertOp || nextTransposeOp) {
1646 if (succeeded(handleTransposeOp())) {
1647 valueToExtractFrom = nextTransposeOp.getVector();
1648 updateStateForNextIteration(valueToExtractFrom);
1654 if (succeeded(handleInsertOpWithMatchingPos(
result)))
1659 if (succeeded(handleInsertOpWithPrefixPos(
result)))
1660 return tryToFoldExtractOpInPlace(
result);
1664 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1670 valueToExtractFrom = nextInsertOp.getDest();
1671 updateStateForNextIteration(valueToExtractFrom);
1674 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1679 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1680 auto vecType = dyn_cast<VectorType>(type);
1681 return vecType && vecType.getRank() == 0;
1691 if (isa<BroadcastOp>(op))
1694 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1702 VectorType srcType = shapeCast.getSourceVectorType();
1704 uint64_t srcRank = srcType.getRank();
1706 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1732 Operation *defOp = extractOp.getSource().getDefiningOp();
1739 if (extractOp.getType() == input.
getType())
1745 auto inputType = llvm::dyn_cast<VectorType>(input.
getType());
1746 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1747 unsigned inputRank = inputType ? inputType.getRank() : 0;
1748 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1749 unsigned extractRank = extractType ? extractType.getRank() : 0;
1752 if (extractRank > inputRank)
1756 assert(inputType &&
"input must be a vector type because of previous checks");
1765 extractType.getShape() != inputShape.take_back(extractRank))
1770 unsigned deltaOverall = inputRank - extractRank;
1771 unsigned deltaBroadcast = broadcastRank - inputRank;
1775 for (
auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1776 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1779 extractOp->setOperands(
1780 llvm::to_vector(llvm::concat<Value>(
ValueRange(input), dynPos)));
1781 extractOp.setStaticPosition(staticPos);
1782 return extractOp.getResult();
1798 if (extractOp.hasDynamicPosition())
1801 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1806 if (shuffleOp.getResultVectorType().getRank() != 1)
1809 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1810 auto shuffleMask = shuffleOp.getMask();
1811 int64_t extractIdx = extractOp.getStaticPosition()[0];
1812 int64_t shuffleIdx = shuffleMask[extractIdx];
1815 if (shuffleIdx < inputVecSize) {
1816 extractOp.setOperand(0, shuffleOp.getV1());
1817 extractOp.setStaticPosition({shuffleIdx});
1819 extractOp.setOperand(0, shuffleOp.getV2());
1820 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1823 return extractOp.getResult();
1829 if (extractOp.hasDynamicPosition())
1832 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1837 auto getDimReverse = [](VectorType type,
int64_t n) {
1838 return type.getShape().take_back(n + 1).front();
1841 llvm::isa<VectorType>(extractOp.getType())
1842 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1844 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1846 if (destinationRank > 0) {
1847 auto destinationType =
1848 llvm::cast<VectorType>(extractOp.getResult().getType());
1849 for (
int64_t i = 0; i < destinationRank; i++) {
1853 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1854 getDimReverse(destinationType, i))
1861 std::reverse(extractedPos.begin(), extractedPos.end());
1864 for (
int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1865 strides.push_back(stride);
1867 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1875 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1877 for (
int64_t i = 0; i < numDimension; i++) {
1878 newStrides.push_back(stride);
1880 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1882 std::reverse(newStrides.begin(), newStrides.end());
1886 extractOp.setStaticPosition(newPosition);
1887 extractOp.setOperand(0, shapeCastOp.getSource());
1888 return extractOp.getResult();
1894 if (extractOp.hasDynamicPosition())
1897 auto extractStridedSliceOp =
1898 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1899 if (!extractStridedSliceOp)
1908 if (extractStridedSliceOp.hasNonUnitStrides())
1914 while (!sliceOffsets.empty()) {
1915 size_t lastOffset = sliceOffsets.size() - 1;
1916 if (sliceOffsets.back() != 0 ||
1917 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1918 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1920 sliceOffsets.pop_back();
1922 unsigned destinationRank = 0;
1923 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1924 destinationRank = vecType.getRank();
1927 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1928 sliceOffsets.size())
1932 assert(extractedPos.size() >= sliceOffsets.size());
1933 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1934 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1935 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1939 extractOp.setStaticPosition(extractedPos);
1940 return extractOp.getResult();
1946 if (extractOp.hasDynamicPosition())
1950 llvm::isa<VectorType>(extractOp.getType())
1951 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1953 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1963 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1964 insertOp.getSourceVectorType().getRank();
1965 if (destinationRank > insertOp.getSourceVectorType().getRank())
1970 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1971 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1974 bool disjoint =
false;
1976 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1977 int64_t start = insertOffsets[dim];
1979 (dim < insertRankDiff)
1981 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1983 int64_t offset = extractOffsets[dim];
1985 if (start <= offset && offset < end) {
1986 if (dim >= insertRankDiff)
1987 offsetDiffs.push_back(offset - start);
1998 insertOp.getSourceVectorType().getRank() - destinationRank;
1999 for (
int64_t i = 0; i < destinationRank; i++) {
2000 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
2001 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
2005 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2008 extractOp.setStaticPosition(offsetDiffs);
2009 return extractOp.getResult();
2013 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2026 if (extractOp.hasDynamicPosition())
2030 auto fromElementsOp = extractOp.getSource().
getDefiningOp<FromElementsOp>();
2031 if (!fromElementsOp)
2035 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2036 if (vecType.isScalable())
2040 int64_t rank = vecType.getRank();
2042 if (extractOp.getType() != vecType.getElementType())
2045 "unexpected number of indices");
2050 for (
int i = rank - 1; i >= 0; --i) {
2051 flatIndex +=
indices[i] * stride;
2052 stride *= vecType.getDimSize(i);
2054 return fromElementsOp.getElements()[flatIndex];
2059template <
typename OpType,
typename AdaptorType>
2062 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2063 OperandRange dynamicPosition = op.getDynamicPosition();
2066 if constexpr (std::is_same_v<OpType, ExtractOp>)
2067 vectorShape = op.getSourceVectorType().getShape();
2072 if (!dynamicPosition.size())
2079 bool opChange =
false;
2080 for (
unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2081 if (ShapedType::isStatic(staticPosition[i]))
2085 if (
auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2086 int64_t value = attr.getInt();
2090 staticPosition[i] = attr.getInt();
2095 operands.push_back(position);
2099 op.setStaticPosition(staticPosition);
2100 op.getOperation()->setOperands(operands);
2102 return op.getResult();
2112 if (!is_contained(staticPos, poisonVal))
2115 return ub::PoisonAttr::get(context);
2129 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2134 if (denseAttr.isSplat()) {
2136 if (
auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2141 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2142 if (vecTy.isScalable())
2145 if (extractOp.hasDynamicPosition()) {
2160 copy(extractOp.getStaticPosition(), completePositions.begin());
2163 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2166 if (
auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2168 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2171 newAttr = *denseValuesBegin;
2177OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2181 if (getNumIndices() == 0 && getSource().
getType() == getResult().
getType())
2188 SmallVector<Value> operands = {getSource()};
2192 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2198 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2213 return inplaceFolded;
2219class ExtractOpFromBroadcast final :
public OpRewritePattern<ExtractOp> {
2223 LogicalResult matchAndRewrite(ExtractOp extractOp,
2224 PatternRewriter &rewriter)
const override {
2227 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2233 BroadcastableToResult::Success)
2242class ExtractOpFromCreateMask final :
public OpRewritePattern<ExtractOp> {
2246 LogicalResult matchAndRewrite(ExtractOp extractOp,
2247 PatternRewriter &rewriter)
const override {
2249 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2253 VectorType extractedMaskType =
2254 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2256 if (!extractedMaskType)
2259 auto maskOperands = createMaskOp.getOperands();
2260 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2261 VectorType maskType = createMaskOp.getVectorType();
2263 bool containsUnknownDims =
false;
2266 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2268 int64_t pos = extractOpPos[dimIdx];
2269 Value operand = maskOperands[dimIdx];
2270 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2273 containsUnknownDims =
true;
2277 int64_t createMaskBound =
2278 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2280 if (pos != ShapedType::kDynamic) {
2283 allFalse |= pos >= createMaskBound;
2284 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2288 containsUnknownDims =
true;
2295 }
else if (!containsUnknownDims) {
2297 extractOp, extractedMaskType,
2298 maskOperands.drop_front(extractOpPos.size()));
2307class ExtractOpFromConstantMask final :
public OpRewritePattern<ExtractOp> {
2311 LogicalResult matchAndRewrite(ExtractOp extractOp,
2312 PatternRewriter &rewriter)
const override {
2313 auto constantMaskOp =
2314 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2315 if (!constantMaskOp)
2318 Type resultType = extractOp.getResult().getType();
2319 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2321 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2322 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2324 VectorType maskType = constantMaskOp.getVectorType();
2327 for (
size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2328 int64_t pos = extractOpPos[dimIdx];
2329 if (pos == ShapedType::kDynamic) {
2332 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2341 if (pos >= maskDimSizes[dimIdx]) {
2342 if (extractedMaskType) {
2354 if (extractedMaskType) {
2358 extractOp, extractedMaskType,
2359 maskDimSizes.drop_front(extractOpPos.size()));
2372LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2373 PatternRewriter &rewriter) {
2374 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2378 VectorType sourceType = castOp.getSourceVectorType();
2379 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2383 if (sourceType.getNumElements() != targetType.getNumElements())
2387 castOp.getSource());
2397LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2398 PatternRewriter &rewriter) {
2400 if (extractOp.hasDynamicPosition())
2404 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2409 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2410 if (!fromElementsOp)
2412 VectorType inputType = fromElementsOp.getType();
2415 if (resultType.isScalable() || inputType.isScalable())
2420 SmallVector<int64_t> firstElementPos =
2421 llvm::to_vector(extractOp.getStaticPosition());
2422 firstElementPos.append(resultType.getRank(), 0);
2425 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2426 flatIndex += firstElementPos[i] * stride;
2427 stride *= inputType.getDimSize(i);
2432 extractOp, resultType,
2433 fromElementsOp.getElements().slice(flatIndex,
2434 resultType.getNumElements()));
2446struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2448 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2449 PatternRewriter &rewriter)
const override {
2450 VectorType sourceType = extractOp.getSourceVectorType();
2451 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2455 if (sourceType.getNumElements() != outType.getNumElements())
2457 extractOp,
"extract to vector with fewer elements");
2461 if (llvm::any_of(extractOp.getMixedPosition(),
2462 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2464 "leaving for extract poison folder");
2467 extractOp.getSource());
2488struct FoldExtractFromInsertUnitDim final
2489 : OpRewritePattern<vector::ExtractOp> {
2492 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2493 PatternRewriter &rewriter)
const override {
2494 if (extractOp.hasDynamicPosition())
2497 auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
2498 if (!insertOp || insertOp.hasDynamicPosition())
2501 ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
2502 ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
2505 if (extractPos.size() >= insertPos.size() ||
2506 extractPos != insertPos.take_front(extractPos.size()))
2512 auto srcVecType = extractOp.getSourceVectorType();
2513 for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
2514 if (srcVecType.getDimSize(i) != 1)
2517 Value
inserted = insertOp.getValueToStore();
2518 Type extractedType = extractOp.getResult().getType();
2519 if (isa<VectorType>(
inserted.getType())) {
2526 extractOp, extractOp.getResult().
getType(),
2527 insertOp.getValueToStore());
2535void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2536 MLIRContext *context) {
2537 results.
add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2538 ExtractOpFromConstantMask, ExtractToShapeCast,
2539 FoldExtractFromInsertUnitDim>(context);
2540 results.
add(foldExtractFromShapeCastToShapeCast);
2541 results.
add(foldExtractFromFromElements);
2546 for (
auto attr : arrayAttr)
2547 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2554std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2565 if (operands.empty())
2568 return llvm::all_of(operands, [&](
Value operand) {
2570 return currentDef == defOp;
2588 auto fromElementsOp =
2589 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2590 if (!fromElementsOp)
2593 llvm::append_range(results, fromElementsOp.getElements());
2610 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2614 if (isa<VectorType>(bcastOp.getSource().getType()))
2617 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2619 Value scalar = bcastOp.getSource();
2620 results.assign(resultVecType.getNumElements(), scalar);
2624LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2625 SmallVectorImpl<OpFoldResult> &results) {
2630 if (
auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2631 setOperand(shapeCast.getSource());
2639ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2640 ToElementsOp::Adaptor adaptor,
2641 SmallVectorImpl<Type> &inferredReturnTypes) {
2642 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2643 Type elType = vecType.getElementType();
2644 inferredReturnTypes.append(vecType.getNumElements(), elType);
2666 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2671 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2675 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2680 int64_t dstRank = dstShape.size();
2681 int64_t srcRank = srcShape.size();
2684 auto srcElems = vector::ToElementsOp::create(
2685 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2687 int64_t dstCount = llvm::product_of(dstShape);
2690 replacements.reserve(dstCount);
2715 for (
int64_t lin = 0; lin < dstCount; ++lin) {
2718 for (
int64_t k = 0; k < srcRank; ++k)
2719 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2722 replacements.push_back(srcElems.getResult(srcLin));
2725 rewriter.
replaceOp(toElementsOp, replacements);
2730void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2731 MLIRContext *context) {
2732 results.
add<ToElementsOfBroadcast>(context);
2752 OperandRange fromElemsOperands = fromElementsOp.getElements();
2753 if (fromElemsOperands.empty())
2756 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2764 Value toElementsInput = toElementsOp.getSource();
2765 if (fromElementsOp.getType() == toElementsInput.
getType() &&
2766 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2767 return toElementsInput;
2787 if (llvm::any_of(elements, [](
Attribute attr) {
2793 auto destVecType = fromElementsOp.getDest().getType();
2794 auto destEltType = destVecType.getElementType();
2795 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2800 auto convertedElements = llvm::map_to_vector(elements, [&](
Attribute attr) {
2807OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2824 if (!llvm::all_equal(fromElementsOp.getElements()))
2827 fromElementsOp, fromElementsOp.getType(),
2828 fromElementsOp.getElements().front());
2856 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2860 if (fromElements.getType().getNumElements() == 1)
2871 for (
auto [insertIndex, element] :
2872 llvm::enumerate(fromElements.getElements())) {
2875 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2878 "element not from vector.extract");
2883 if (insertIndex == 0) {
2884 source = extractOp.getSource();
2885 }
else if (extractOp.getSource() != source) {
2887 "element from different vector");
2891 int64_t rank = position.size();
2892 assert(rank == source.getType().getRank() &&
2893 "scalar extract must have full rank position");
2904 if (insertIndex == 0) {
2905 const int64_t numElms = fromElements.getType().getNumElements();
2908 while (
index > 0 && position[
index - 1] == 0 &&
2909 numSuffixElms < numElms) {
2910 numSuffixElms *= source.getType().getDimSize(
index - 1);
2913 if (numSuffixElms != numElms) {
2915 fromElements,
"elements do not form a suffix of source");
2917 expectedPosition = llvm::to_vector(position);
2918 combinedPosition = position.drop_back(rank -
index);
2922 else if (expectedPosition != position) {
2924 fromElements,
"elements not in ascending order (static order)");
2926 increment(expectedPosition, source.getType().getShape());
2929 auto extracted = rewriter.
createOrFold<vector::ExtractOp>(
2930 fromElements.getLoc(), source, combinedPosition);
2933 fromElements, fromElements.getType(), extracted);
2941 for (
int dim : llvm::reverse(llvm::seq<int>(0,
indices.size()))) {
2960void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2962 setResultRanges(getResult(), argRanges.front());
2965std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2966 return llvm::to_vector<4>(getResultVectorType().
getShape());
2971static llvm::SetVector<int64_t>
2974 int64_t rankDiff = dstShape.size() - srcShape.size();
2977 for (
auto [s1, s2] :
2978 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2980 assert(s1 == 1 &&
"expected \"dim-1\" broadcasting");
2988llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2990 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2993 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
3009Value BroadcastOp::createOrFoldBroadcastOp(
3010 OpBuilder &
b, Value value, ArrayRef<int64_t> dstShape,
3011 const llvm::SetVector<int64_t> &broadcastedDims) {
3012 assert(!dstShape.empty() &&
"unexpected empty dst shape");
3015 SmallVector<int64_t> checkShape;
3016 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
3017 if (broadcastedDims.contains(i))
3019 checkShape.push_back(dstShape[i]);
3021 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
3022 "ill-formed broadcastedDims contains values not confined to "
3025 Location loc = value.
getLoc();
3027 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
3028 VectorType dstVectorType = VectorType::get(dstShape, elementType);
3031 if (!srcVectorType) {
3032 assert(checkShape.empty() &&
3033 "ill-formed createOrFoldBroadcastOp arguments");
3034 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
3037 assert(srcVectorType.getShape().equals(checkShape) &&
3038 "ill-formed createOrFoldBroadcastOp arguments");
3048 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
3049 broadcastShape.reserve(dstShape.size());
3065 int64_t nextSrcShapeDim = broadcastedDims.size();
3066 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
3067 if (broadcastedDims.contains(i)) {
3072 broadcastShape.push_back(dstShape[i]);
3073 permutation[i] = broadcastShape.size() - 1;
3079 permutation[i] = nextSrcShapeDim++;
3083 llvm::append_range(broadcastShape, srcVectorType.getShape());
3088 "unexpected \"dim-1\" broadcast");
3090 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3092 vector::BroadcastableToResult::Success &&
3093 "must be broadcastable");
3094 Value res =
b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3097 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3098 if (permutation[i] != i)
3099 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3105 Type srcType, VectorType dstVectorType,
3106 std::pair<VectorDim, VectorDim> *mismatchingDims) {
3108 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3112 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3116 int64_t srcRank = srcVectorType.getRank();
3117 int64_t dstRank = dstVectorType.getRank();
3118 if (srcRank > dstRank)
3122 int64_t lead = dstRank - srcRank;
3123 for (
int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3126 bool foundMismatchingDims =
false;
3129 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3130 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3131 if (srcDim != 1 && srcDim != dstDim)
3132 foundMismatchingDims =
true;
3135 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3136 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3137 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3140 (srcDimScalableFlag != dstDimScalableFlag &&
3141 (srcDim != 1 || srcDimScalableFlag)))
3142 foundMismatchingDims =
true;
3144 if (foundMismatchingDims) {
3145 if (mismatchingDims !=
nullptr) {
3146 mismatchingDims->first.dim = srcDim;
3147 mismatchingDims->first.isScalable = srcDimScalableFlag;
3149 mismatchingDims->second.dim = dstDim;
3150 mismatchingDims->second.isScalable = dstDimScalableFlag;
3159LogicalResult BroadcastOp::verify() {
3160 std::pair<VectorDim, VectorDim> mismatchingDims;
3162 getSourceType(), getResultVectorType(), &mismatchingDims);
3166 return emitOpError(
"source rank higher than destination rank");
3169 << (mismatchingDims.first.isScalable ?
"[" :
"")
3170 << mismatchingDims.first.dim
3171 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
3172 << (mismatchingDims.second.isScalable ?
"[" :
"")
3173 << mismatchingDims.second.dim
3174 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
3177 return emitOpError(
"source type is not a vector");
3178 llvm_unreachable(
"unexpected vector.broadcast op error");
3185 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3189 VectorType srcType = srcShapeCast.getSourceVectorType();
3190 VectorType destType = broadcastOp.getResultVectorType();
3198 srcShapeCast.getResultVectorType().getShape();
3201 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3202 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3203 shapecastShape.take_back(numTrailingDims)))
3206 assert(all_of(srcShape.drop_back(numTrailingDims),
3207 [](
int64_t E) { return E == 1; }) &&
3208 all_of(shapecastShape.drop_back(numTrailingDims),
3209 [](
int64_t E) { return E == 1; }) &&
3210 "ill-formed shape_cast");
3212 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3216OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3217 if (getSourceType() == getResultVectorType())
3222 if (!adaptor.getSource())
3224 auto vectorType = getResultVectorType();
3225 if (
auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3226 if (vectorType.getElementType() != attr.getType())
3230 if (
auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3231 if (vectorType.getElementType() != attr.getType())
3235 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3245struct BroadcastFolder :
public OpRewritePattern<BroadcastOp> {
3248 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3249 PatternRewriter &rewriter)
const override {
3250 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3254 broadcastOp.getResultVectorType(),
3255 srcBroadcast.getSource());
3268struct BroadcastToShapeCast final
3269 :
public OpRewritePattern<vector::BroadcastOp> {
3271 LogicalResult matchAndRewrite(vector::BroadcastOp
broadcast,
3272 PatternRewriter &rewriter)
const override {
3274 auto sourceType = dyn_cast<VectorType>(
broadcast.getSourceType());
3277 broadcast,
"source is a scalar, shape_cast doesn't support scalar");
3281 if (sourceType.getNumElements() != outType.getNumElements()) {
3283 broadcast,
"broadcast to a greater number of elements");
3293void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3294 MLIRContext *context) {
3295 results.
add<BroadcastFolder, BroadcastToShapeCast>(context);
3302LogicalResult ShuffleOp::verify() {
3303 VectorType resultType = getResultVectorType();
3304 VectorType v1Type = getV1VectorType();
3305 VectorType v2Type = getV2VectorType();
3307 int64_t resRank = resultType.getRank();
3308 int64_t v1Rank = v1Type.getRank();
3309 int64_t v2Rank = v2Type.getRank();
3310 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3311 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3312 if (!wellFormed0DCase && !wellFormedNDCase)
3316 for (int64_t r = 1; r < v1Rank; ++r) {
3317 int64_t resDim = resultType.getDimSize(r);
3318 int64_t v1Dim = v1Type.getDimSize(r);
3319 int64_t v2Dim = v2Type.getDimSize(r);
3320 if (resDim != v1Dim || v1Dim != v2Dim)
3324 ArrayRef<int64_t> mask = getMask();
3325 int64_t maskLength = mask.size();
3326 if (maskLength <= 0)
3328 if (maskLength != resultType.getDimSize(0))
3331 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3332 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3333 for (
auto [idx, maskPos] : llvm::enumerate(mask)) {
3335 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
3341ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3342 ShuffleOp::Adaptor adaptor,
3343 SmallVectorImpl<Type> &inferredReturnTypes) {
3344 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3348 auto v1Rank = v1Type.getRank();
3351 SmallVector<int64_t, 4> shape;
3352 shape.reserve(v1Rank);
3353 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3356 llvm::append_range(shape, v1Type.getShape().drop_front());
3357 inferredReturnTypes.push_back(
3358 VectorType::get(shape, v1Type.getElementType()));
3362template <
typename T>
3365 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3366 return value == expected++;
3373 auto v1Type = op.getV1VectorType();
3374 auto v2Type = op.getV2VectorType();
3375 auto mask = op.getMask();
3388 if (!isV1Poison && !isV2Poison)
3391 int64_t v1Size = op.getV1VectorType().getDimSize(0);
3392 bool changed =
false;
3394 for (
int64_t &idx : newMask) {
3395 if (idx == ShuffleOp::kPoisonIndex)
3397 if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
3398 idx = ShuffleOp::kPoisonIndex;
3406 op.setMask(newMask);
3407 return op.getResult();
3416 return ub::PoisonAttr::get(context);
3423 auto v1Type = op.getV1VectorType();
3424 if (v1Type.getRank() != 1)
3436 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3439 v2Elements = to_vector(v2DenseAttr.getValues<
Attribute>());
3440 poisonElement = v2Elements[0];
3443 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3446 v1Elements = to_vector(v1DenseAttr.getValues<
Attribute>());
3447 poisonElement = v1Elements[0];
3452 int64_t v1Size = v1Type.getDimSize(0);
3453 for (
int64_t maskIdx : mask) {
3456 if (maskIdx == ShuffleOp::kPoisonIndex) {
3457 indexedElm = poisonElement;
3459 if (maskIdx < v1Size)
3460 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3462 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3465 results.push_back(indexedElm);
3471OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3472 auto v1Type = getV1VectorType();
3474 assert(!v1Type.isScalable() && !getV2VectorType().isScalable() &&
3475 "Vector shuffle does not support scalable vectors");
3479 if (v1Type.getRank() == 0)
3487 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3488 if (!v1Attr || !v2Attr)
3503struct Canonicalize0DShuffleOp :
public OpRewritePattern<ShuffleOp> {
3506 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3507 PatternRewriter &rewriter)
const override {
3508 VectorType v1VectorType = shuffleOp.getV1VectorType();
3509 ArrayRef<int64_t> mask = shuffleOp.getMask();
3510 if (v1VectorType.getRank() > 0)
3512 if (mask.size() != 1)
3514 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3532static Value getScalarSplatSource(Value value) {
3538 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3545 if (isa<VectorType>(
broadcast.getSourceType()))
3553class ShuffleSplat final :
public OpRewritePattern<ShuffleOp> {
3557 LogicalResult matchAndRewrite(ShuffleOp op,
3558 PatternRewriter &rewriter)
const override {
3559 Value splat = getScalarSplatSource(op.getV1());
3560 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3570class ShuffleInterleave :
public OpRewritePattern<ShuffleOp> {
3574 LogicalResult matchAndRewrite(ShuffleOp op,
3575 PatternRewriter &rewriter)
const override {
3576 VectorType resultType = op.getResultVectorType();
3577 if (resultType.isScalable())
3579 op,
"ShuffleOp can't represent a scalable interleave");
3581 if (resultType.getRank() != 1)
3583 op,
"ShuffleOp can't represent an n-D interleave");
3585 VectorType sourceType = op.getV1VectorType();
3586 if (sourceType != op.getV2VectorType() ||
3587 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3589 op,
"ShuffleOp types don't match an interleave");
3592 ArrayRef<int64_t> shuffleMask = op.getMask();
3593 int64_t resultVectorSize = resultType.getNumElements();
3594 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3595 int64_t maskValueA = shuffleMask[i * 2];
3596 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3597 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3599 "ShuffleOp mask not interleaving");
3615class FoldUnusedShuffleOperand final :
public OpRewritePattern<ShuffleOp> {
3619 LogicalResult matchAndRewrite(ShuffleOp op,
3620 PatternRewriter &rewriter)
const override {
3622 if (llvm::all_of(op.getMask(), [](int64_t mask) {
3623 return mask == ShuffleOp::kPoisonIndex;
3630 auto replaceOperandWithPoison = [&](OpOperand &operand) {
3633 Value poison = ub::PoisonOp::create(rewriter, op.getLoc(),
3642 int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
3643 ? op.getV1VectorType().getDimSize(0)
3645 bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3646 return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
3648 if (!isV1Used && succeeded(replaceOperandWithPoison(op.getV1Mutable())))
3652 bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3653 return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
3655 if (!isV2Used && succeeded(replaceOperandWithPoison(op.getV2Mutable())))
3663void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3664 MLIRContext *context) {
3665 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
3666 FoldUnusedShuffleOperand>(context);
3673void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3675 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3678void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3679 Value source, Value dest) {
3680 auto vectorTy = cast<VectorType>(dest.
getType());
3681 build(builder,
result, source, dest,
3682 SmallVector<int64_t>(vectorTy.getRank(), 0));
3685void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3686 Value source, Value dest, int64_t position) {
3687 build(builder,
result, source, dest, ArrayRef<int64_t>{position});
3690void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3691 Value source, Value dest, OpFoldResult position) {
3692 build(builder,
result, source, dest, ArrayRef<OpFoldResult>{position});
3695void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3696 Value source, Value dest,
3697 ArrayRef<int64_t> position) {
3698 SmallVector<OpFoldResult> posVals;
3699 posVals.reserve(position.size());
3700 llvm::transform(position, std::back_inserter(posVals),
3702 build(builder,
result, source, dest, posVals);
3705void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3706 Value source, Value dest,
3707 ArrayRef<OpFoldResult> position) {
3708 SmallVector<int64_t> staticPos;
3709 SmallVector<Value> dynamicPos;
3711 build(builder,
result, source, dest, dynamicPos,
// Verifies vector.insert. Based on the visible diagnostics, it rejects:
//  * a 0-d vector as the value to store (a scalar is expected instead),
//  * more position indices than the destination vector rank,
//  * a vector value to store whose rank + number of positions != dest rank,
//  * a scalar value to store whose number of positions != dest rank,
//  * a static position that is negative or not smaller than the matching
//    destination dimension.
// NOTE(review): several source lines (e.g. 3718, 3724, 3734, 3739, 3742, 3744
// and the end of the function) are missing from this listing.
3715LogicalResult InsertOp::verify() {
3716 if (
auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3717 if (srcTy.getRank() == 0)
3719 "expected a scalar instead of a 0-d vector as the source operand");
3721 SmallVector<OpFoldResult> position = getMixedPosition();
3722 auto destVectorType = getDestVectorType();
3723 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
3725 "expected position attribute of rank no greater than dest vector rank");
3726 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3727 if (srcVectorType &&
3728 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3729 static_cast<unsigned>(destVectorType.getRank())))
3730 return emitOpError(
"expected position attribute rank + source rank to "
3731 "match dest vector rank");
3732 if (!srcVectorType &&
3733 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
3735 "expected position attribute rank to match the dest vector rank");
// Only static (attribute) positions can be range-checked here.
3736 for (
auto [idx, pos] : llvm::enumerate(position)) {
3737 if (
auto attr = dyn_cast<Attribute>(pos)) {
3738 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3740 destVectorType.getDimSize(idx))) {
3741 return emitOpError(
"expected position attribute #")
3743 <<
" to be a non-negative integer smaller than the "
3745 "dest vector dimension";
// Fragment: the enclosing function's signature (before source line 3758) is
// not visible in this listing. Copies `positions` into the leading entries of
// `completePositions` after asserting that it fits.
3758 assert(positions.size() <= completePositions.size() &&
3759 "positions size must be less than or equal to destTy rank");
3760 copy(positions, completePositions.begin());
// Canonicalization: matches a vector.insert whose value to store is a vector
// with the same number of elements as the destination (so it presumably
// covers the whole destination). The op is then replaced by an op built from
// the destination type and the value to store; per the pattern name this is a
// broadcast, but the created op (source line 3779) is not visible here.
3768class InsertToBroadcast final :
public OpRewritePattern<InsertOp> {
3772 LogicalResult matchAndRewrite(InsertOp insertOp,
3773 PatternRewriter &rewriter)
const override {
3775 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3776 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3777 srcVecType.getNumElements())
3780 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
// Canonicalization: matches a vector.insert where the value to store and the
// destination are both splats of the same scalar (as reported by
// getScalarSplatSource).
// NOTE(review): the rewrite itself (after source line 3794) is not visible in
// this listing.
3786class InsertSplatToSplat final :
public OpRewritePattern<InsertOp> {
3790 LogicalResult matchAndRewrite(InsertOp op,
3791 PatternRewriter &rewriter)
const override {
3793 Value splat = getScalarSplatSource(op.getValueToStore());
3794 if (!splat || getScalarSplatSource(op.getDest()) != splat)
// Canonicalization: starting at `op`, it follows each vector.insert's
// destination operand to collect a chain of static-position inserts. It
// tracks which linear destination elements the chain writes. If every element
// of a non-scalable destination is initialized, it gathers one Value per
// element into `elements`; vector-valued inserts are split with
// vector.to_elements.
// NOTE(review): several source lines are missing from this listing (gaps in
// the embedded line numbers), including the bodies of most guards and the
// final rewrite.
3822class InsertChainFullyInitialized final :
public OpRewritePattern<InsertOp> {
3825 LogicalResult matchAndRewrite(InsertOp op,
3826 PatternRewriter &rewriter)
const override {
3828 VectorType destTy = op.getDestVectorType();
3829 if (destTy.isScalable())
// Look for a user InsertOp that takes this result as its destination
// (presumably so the pattern only fires on the last insert of a chain).
3832 for (Operation *user : op.getResult().getUsers())
3833 if (
auto insertOp = dyn_cast<InsertOp>(user))
3834 if (insertOp.getDest() == op.getResult())
// Collect the chain from `op` backwards through destination operands.
3837 InsertOp currentOp = op;
3838 SmallVector<InsertOp> chainInsertOps;
3841 if (currentOp.hasDynamicPosition())
3844 chainInsertOps.push_back(currentOp);
3845 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3848 if (currentOp && !currentOp->hasOneUse())
3852 int64_t vectorSize = destTy.getNumElements();
3853 int64_t initializedCount = 0;
3854 SmallVector<bool> initializedDestIdxs(vectorSize,
false);
3855 SmallVector<int64_t> pendingInsertPos;
3856 SmallVector<int64_t> pendingInsertSize;
3857 SmallVector<Value> pendingInsertValues;
3859 for (
auto insertOp : chainInsertOps) {
3861 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3865 int64_t insertBeginPosition =
// A scalar insert writes one element; a vector insert writes all of its
// elements.
3870 int64_t insertSize = 1;
3871 if (
auto srcVectorType =
3872 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3873 insertSize = srcVectorType.getNumElements();
3875 assert(insertBeginPosition + insertSize <= vectorSize &&
3876 "insert would overflow the vector");
// Mark every linear destination index covered by this insert.
3878 for (
auto index : llvm::seq<int64_t>(insertBeginPosition,
3879 insertBeginPosition + insertSize)) {
3880 if (initializedDestIdxs[index])
3882 initializedDestIdxs[index] =
true;
3888 pendingInsertPos.push_back(insertBeginPosition);
3889 pendingInsertSize.push_back(insertSize);
3890 pendingInsertValues.push_back(insertOp.getValueToStore());
3892 if (initializedCount == vectorSize)
3897 if (initializedCount != vectorSize)
// Replay the recorded inserts in reverse order of recording. The chain was
// collected starting from `op`, so reversing visits earlier inserts first.
3900 SmallVector<Value> elements(vectorSize);
3901 for (
auto [insertBeginPosition, insertSize, valueToStore] :
3902 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3903 pendingInsertValues))) {
3904 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3906 if (!srcVectorType) {
3907 elements[insertBeginPosition] = valueToStore;
3911 Repeated<Type> elementToInsertTypes(insertSize,
3912 srcVectorType.getElementType());
3914 auto elementsToInsert = vector::ToElementsOp::create(
3915 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3916 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3917 elements[insertBeginPosition + linearIdx] =
3918 elementsToInsert.getResult(linearIdx);
// Fragment (the signature before source line 3932 is not visible):
// constant-folds a vector.insert whose destination attribute is a
// DenseElementsAttr.
// Guards visible below (their bodies are missing, presumably early exits):
// dynamic positions, scalable destinations, destinations with more than
// `maxVectorSizeFoldThreshold` elements when the insert has more than one use,
// and poison positions.
// The inserted values are then copied over the destination values starting at
// `insertBeginPosition`.
3932 int64_t maxVectorSizeFoldThreshold) {
3933 if (insertOp.hasDynamicPosition())
3936 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3944 VectorType destTy = insertOp.getDestVectorType();
3945 if (destTy.isScalable())
3949 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3950 !insertOp->hasOneUse())
3955 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3962 Type destEltType = destTy.getElementType();
3966 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3967 for (
auto value : denseSource.getValues<
Attribute>())
3973 auto allValues = llvm::to_vector(denseDst.getValues<
Attribute>());
3974 copy(insertedValues, allValues.begin() + insertBeginPosition);
// Fragment (the enclosing signature is not visible): if the destination is
// produced by another vector.insert at the same mixed position, that earlier
// write is fully overwritten. This insert is therefore rewired (operand 1,
// presumably the destination) to the earlier insert's own destination, and
// its own result is returned as an in-place fold.
3983 auto destInsert = insertOp.getDest().
getDefiningOp<InsertOp>();
3987 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3990 insertOp.
setOperand(1, destInsert.getDest());
3991 return insertOp.getResult();
// Registers the InsertOp canonicalization patterns: InsertToBroadcast,
// BroadcastFolder, InsertSplatToSplat and InsertChainFullyInitialized.
3994void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3995 MLIRContext *context) {
3996 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3997 InsertChainFullyInitialized>(context);
// Folds vector.insert. An insert with no position indices whose stored value
// already has the result type folds to the stored value. Otherwise the
// constant-folding helper is called with `vectorSizeFoldThreshold` (256) as
// its element-count limit.
// NOTE(review): several source lines (4001-4002, 4004-4006, 4009-4011,
// 4013-4017, 4019-4020, 4023-4025 and the end) are missing from this listing.
4000OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
4003 constexpr int64_t vectorSizeFoldThreshold = 256;
4007 if (getNumIndices() == 0 && getValueToStoreType() ==
getType())
4008 return getValueToStore();
4012 SmallVector<Value> operands = {getValueToStore(), getDest()};
4018 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
4021 *
this, adaptor.getValueToStore(), adaptor.getDest(),
4022 vectorSizeFoldThreshold)) {
4026 return inplaceFolded;
// Builds a vector.insert_strided_slice with operands (source, dest) and the
// `offsets` and `strides` attributes.
// NOTE(review): source lines 4038-4040, 4042 and 4044+ (result type and the
// attribute values) are missing from this listing.
4033void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4034 Value source, Value dest,
4035 ArrayRef<int64_t> offsets,
4036 ArrayRef<int64_t> strides) {
4037 result.addOperands({source, dest});
4041 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(
result.name),
4043 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(
result.name),
// Verification helper: emits an error on `op` when `arrayAttr` has more
// entries than `shape` has dimensions.
// NOTE(review): the function name and leading parameters (source lines
// 4049-4051) are missing from this listing.
4048template <
typename OpType>
4052 StringRef attrName) {
4053 if (arrayAttr.size() >
shape.size())
4054 return op.emitOpError(
"expected ")
4055 << attrName <<
" attribute of rank no greater than vector rank";
// Verification helper: every integer in `arrayAttr` must lie in [min, upper),
// otherwise an error is emitted on `op`.
// NOTE(review): `upper` is computed on source lines 4069-4071, which are
// missing; presumably it is derived from a max bound and `halfOpen`
// (default true) -- TODO confirm.
4062template <
typename OpType>
4066 bool halfOpen =
true) {
4067 for (
auto attr : arrayAttr) {
4068 auto val = llvm::cast<IntegerAttr>(attr).getInt();
4072 if (val < min || val >= upper)
4073 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
4074 <<
min <<
", " << upper <<
")";
// Verification helper: pairs each entry of `arrayAttr` with the matching
// `shape` dimension (zip_first, so iteration follows arrayAttr). Each value
// must lie in [min, max), otherwise an error naming the dimension index is
// emitted on `op`.
// NOTE(review): `max` is computed on missing source lines 4090-4092,
// presumably from the paired dimension size -- TODO confirm.
4082template <
typename OpType>
4087 for (
auto [
index, attrDimPair] :
4088 llvm::enumerate(llvm::zip_first(arrayAttr,
shape))) {
4089 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
4093 if (val < min || val >=
max)
4094 return op.emitOpError(
"expected ")
4095 << attrName <<
" dimension " <<
index <<
" to be confined to ["
4096 <<
min <<
", " <<
max <<
")";
// Verification helper: for each dimension, the sum of the entries of
// `arrayAttr1` and `arrayAttr2` (e.g. offset + size) must lie in [0, max),
// otherwise an error naming both attributes and the dimension is emitted on
// `op`. Both arrays must have no more entries than `shape` (asserted).
// NOTE(review): the lower-bound check is `val1 + val2 < 0`, but the diagnostic
// prints `min` as the lower bound. Unless `min` is always 0 here, the message
// can misstate the accepted range. `min`/`max` come from source lines
// 4107-4110 and 4117-4119, which are missing -- confirm.
4106template <
typename OpType>
4111 assert(arrayAttr1.size() <=
shape.size());
4112 assert(arrayAttr2.size() <=
shape.size());
4113 for (
auto [
index, it] :
4114 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2,
shape))) {
4115 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
4116 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
4120 if (val1 + val2 < 0 || val1 + val2 >=
max)
4121 return op.emitOpError(
"expected sum(")
4122 << attrName1 <<
", " << attrName2 <<
") dimension " <<
index
4123 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
// Fragment (the signature and surrounding lines are not visible): wraps each
// int64 value as a 64-bit IntegerAttr and returns the collection as an
// ArrayAttr.
4131 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
4133 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
// Verifies vector.insert_strided_slice. Visible checks:
//  * one offset per destination dimension and one stride per source dimension,
//  * the source rank is not greater than the destination rank,
//  * scalable flags of each source dim match the aligned trailing destination
//    dim, and scalable dims have equal base sizes.
// NOTE(review): the offset/stride range checks (source lines 4157-4168) and
// parts of the diagnostics are missing from this listing.
4136LogicalResult InsertStridedSliceOp::verify() {
4137 auto sourceVectorType = getSourceVectorType();
4138 auto destVectorType = getDestVectorType();
4139 auto offsets = getOffsetsAttr();
4140 auto strides = getStridesAttr();
4141 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
4143 "expected offsets of same size as destination vector rank");
4144 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
4145 return emitOpError(
"expected strides of same size as source vector rank");
4146 if (sourceVectorType.getRank() > destVectorType.getRank())
4148 "expected source rank to be no greater than destination rank");
// Source shape left-padded with zeros up to the destination rank, presumably
// consumed by the range checks on the missing lines.
4150 auto sourceShape = sourceVectorType.getShape();
4151 auto destShape = destVectorType.getShape();
4152 SmallVector<int64_t, 4> sourceShapeAsDestShape(
4153 destShape.size() - sourceShape.size(), 0);
4154 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
4155 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
4156 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
4165 offName,
"source vector shape",
// Source dim idx lines up with destination dim idx + rankDiff.
4169 unsigned rankDiff = destShape.size() - sourceShape.size();
4170 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
4171 if (sourceVectorType.getScalableDims()[idx] !=
4172 destVectorType.getScalableDims()[idx + rankDiff]) {
4173 return emitOpError(
"mismatching scalable flags (at source vector idx=")
4176 if (sourceVectorType.getScalableDims()[idx]) {
4177 auto sourceSize = sourceShape[idx];
4178 auto destSize = destShape[idx + rankDiff];
4179 if (sourceSize != destSize) {
4182 << (
" to match the corresponding base size from the input "
4184 << sourceSize << (
" vs ") << destSize << (
")");
// Canonicalization: inserting a splat of scalar X into a destination that is
// itself a splat of X leaves the destination unchanged. The op is replaced by
// its destination.
4194class FoldInsertStridedSliceSplat final
4195 :
public OpRewritePattern<InsertStridedSliceOp> {
4199 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4200 PatternRewriter &rewriter)
const override {
4202 auto dst = insertStridedSliceOp.getDest();
4203 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4204 if (!splat || getScalarSplatSource(dst) != splat)
4207 rewriter.
replaceOp(insertStridedSliceOp, dst);
// Canonicalization: inserting back a slice that was extracted from the same
// destination with identical offsets and strides is a no-op. The op is
// replaced by its destination.
4214class FoldInsertStridedSliceOfExtract final
4215 :
public OpRewritePattern<InsertStridedSliceOp> {
4219 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4220 PatternRewriter &rewriter)
const override {
4221 auto extractStridedSliceOp =
4222 insertStridedSliceOp.getValueToStore()
4223 .getDefiningOp<vector::ExtractStridedSliceOp>();
4225 if (!extractStridedSliceOp)
4228 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4232 if (extractStridedSliceOp.getStrides() !=
4233 insertStridedSliceOp.getStrides() ||
4234 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4237 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
// Canonicalization: folds an insert_strided_slice of a dense constant slice
// into a dense constant destination. Each slice element is written at its
// linearized destination position.
// Guards visible below (their bodies are missing, presumably early exits):
// scalable destinations, destinations with more than vectorSizeFoldThreshold
// (256) elements that have more than one use, and non-unit strides.
// NOTE(review): the constant matching and the final rewrite are on source
// lines missing from this listing.
4244class InsertStridedSliceConstantFolder final
4245 :
public OpRewritePattern<InsertStridedSliceOp> {
4251 static constexpr int64_t vectorSizeFoldThreshold = 256;
4253 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4254 PatternRewriter &rewriter)
const override {
4258 Attribute vectorDestCst;
4262 VectorType destTy = destVector.getType();
4263 if (destTy.isScalable())
4267 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4268 !destVector.hasOneUse())
4272 Attribute sourceCst;
4282 if (op.hasNonUnitStrides())
4285 VectorType sliceVecTy = sourceValue.getType();
4286 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4287 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4288 SmallVector<int64_t, 4> offsets =
getI64SubArray(op.getOffsets());
4289 SmallVector<int64_t, 4> destStrides =
computeStrides(destTy.getShape());
4297 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4298 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4299 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4300 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
// currSlicePosition is a view over the trailing (slice-rank) entries of
// currDestPosition, so advancing it also advances the destination position.
4301 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4302 MutableArrayRef<int64_t> currSlicePosition(
4303 currDestPosition.begin() + rankDifference, currDestPosition.end());
4304 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4307 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
4308 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
4309 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4310 "Invalid slice element");
4311 newValues[linearizedPosition] = *sliceValuesIt;
// Registers the InsertStridedSliceOp canonicalization patterns:
// FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract and
// InsertStridedSliceConstantFolder.
4324void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4325 RewritePatternSet &results, MLIRContext *context) {
4326 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4327 InsertStridedSliceConstantFolder>(context);
// Folds to the value to store when the source and destination vector types
// are identical (the inserted slice then replaces the whole destination).
4330OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4331 if (getSourceVectorType() == getDestVectorType())
4332 return getValueToStore();
// Builder taking lhs, rhs and accumulator values.
// NOTE(review): the body (source lines 4343-4345) is missing from this
// listing.
4341void OuterProductOp::build(OpBuilder &builder, OperationState &
result,
4342 Value
lhs, Value
rhs, Value acc) {
// Prints the op as `lhs, rhs[, acc] ... : lhs-type, rhs-type`. The
// accumulator is printed under a condition on a missing line (4349),
// presumably only when an accumulator is present.
// NOTE(review): source lines 4349, 4351-4352 and 4354+ are missing from this
// listing.
4347void OuterProductOp::print(OpAsmPrinter &p) {
4348 p <<
" " << getLhs() <<
", " << getRhs();
4350 p <<
", " << getAcc();
4353 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
// Parses vector.outerproduct. At least two operands are required, and the
// first operand type must be a vector. The result type is inferred as either:
//  * 2-D [lhs dim 0, rhs dim 0], with scalable flags taken from both operands,
//  * 1-D [lhs dim 0] (presumably when rhs is a scalar; the selecting condition
//    is on a missing line).
// If no `kind` attribute was parsed, the default combining kind is attached.
// NOTE(review): many source lines are missing from this listing (gaps in the
// embedded line numbers).
4356ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &
result) {
4357 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4364 if (operandsInfo.size() < 2)
4366 "expected at least 2 operands");
4367 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4368 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4371 "expected vector type for operand #1");
4375 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4376 vRHS.getScalableDims()[0]};
4377 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4378 vLHS.getElementType(), scalableDimsRes);
4381 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4382 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4386 if (!
result.attributes.get(OuterProductOp::getKindAttrName(
result.name))) {
4387 result.attributes.append(
4388 OuterProductOp::getKindAttrName(
result.name),
4389 CombiningKindAttr::get(
result.getContext(),
4390 OuterProductOp::getDefaultKind()));
4396 (operandsInfo.size() > 2 &&
// Verifies vector.outerproduct. The lhs must be a 1-d vector. With a vector
// rhs, the rhs must be 1-d and the result 2-d with dims matching lhs/rhs; if
// lhs is scalable, rhs must be too. Otherwise the result must be 1-d with dim 0
// matching lhs. An accumulator, if present, must have the result type, and a
// `kind` attribute is required.
// NOTE(review): the branch conditions selecting the vector-rhs / scalar-rhs
// cases and some diagnostics are on lines missing from this listing.
4401LogicalResult OuterProductOp::verify() {
4402 Type tRHS = getOperandTypeRHS();
4403 VectorType vLHS = getOperandVectorTypeLHS(),
4404 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4405 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4407 if (vLHS.getRank() != 1)
4408 return emitOpError(
"expected 1-d vector for operand #1");
// Presumably the vector-rhs case: a 2-d outer product.
4412 if (vRHS.getRank() != 1)
4413 return emitOpError(
"expected 1-d vector for operand #2");
4414 if (vRES.getRank() != 2)
4416 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4417 return emitOpError(
"expected #1 operand dim to match result dim #1");
4418 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4419 return emitOpError(
"expected #2 operand dim to match result dim #2");
4420 if (vLHS.isScalable() && !vRHS.isScalable()) {
4424 "expected either both or only #2 operand dim to be scalable");
// Presumably the scalar-rhs case: a 1-d result.
4428 if (vRES.getRank() != 1)
4430 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4431 return emitOpError(
"expected #1 operand dim to match result dim #1");
4434 if (vACC && vACC != vRES)
4435 return emitOpError(
"expected operand #3 of same type as result type");
4437 if (!getKindAttr()) {
4438 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
4439 "'vector.kind<add>')");
4444 return emitOpError(
"unsupported outerproduct type");
// The expected mask is an i1 vector with the result vector's shape and
// scalable dims.
4453Type OuterProductOp::getExpectedMaskType() {
4454 auto vecType = this->getResultVectorType();
4455 return VectorType::get(vecType.getShape(),
4456 IntegerType::get(vecType.getContext(), 1),
4457 vecType.getScalableDims());
// Fragment (the signature is not visible): computes a strided-slice result
// type. The leading dims take the slice `sizes`; the remaining trailing dims
// keep `vectorType`'s sizes. Element type and scalable flags come from
// `vectorType`.
4471 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4473 shape.reserve(vectorType.getRank());
4475 for (
unsigned e = offsets.size(); idx < e; ++idx)
4476 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4477 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4478 shape.push_back(vectorType.getShape()[idx]);
4480 return VectorType::get(
shape, vectorType.getElementType(),
4481 vectorType.getScalableDims());
// Builds a vector.extract_strided_slice of `source`. The result type is
// derived from the offsets/sizes/strides attributes, which are also attached
// to the op.
// NOTE(review): source lines 4489-4493, 4496, 4498 and 4500+ are missing from
// this listing.
4484void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4485 Value source, ArrayRef<int64_t> offsets,
4486 ArrayRef<int64_t> sizes,
4487 ArrayRef<int64_t> strides) {
4488 result.addOperands(source);
4494 offsetsAttr, sizesAttr, stridesAttr));
4495 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(
result.name),
4497 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(
result.name),
4499 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(
result.name),
// Verifies vector.extract_strided_slice. Visible checks:
//  * offsets, sizes and strides have the same length,
//  * the result type equals the type inferred from the source and attributes,
//  * scalable source dims are taken whole (slice size == base dim size).
// NOTE(review): the per-attribute range checks (source lines 4516-4535) are
// missing from this listing.
4503LogicalResult ExtractStridedSliceOp::verify() {
4504 auto type = getSourceVectorType();
4505 auto offsets = getOffsetsAttr();
4506 auto sizes = getSizesAttr();
4507 auto strides = getStridesAttr();
4508 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4510 "expected offsets, sizes and strides attributes of same size");
4512 auto shape = type.getShape();
4513 auto offName = getOffsetsAttrName();
4514 auto sizesName = getSizesAttrName();
4515 auto stridesName = getStridesAttrName();
4531 shape, offName, sizesName,
4536 offsets, sizes, strides);
4537 if (getResult().
getType() != resultType)
4538 return emitOpError(
"expected result type to be ") << resultType;
4540 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
4541 if (type.getScalableDims()[idx]) {
4542 auto inputDim = type.getShape()[idx];
4543 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4544 if (inputDim != inputSize)
4547 << (
" to match the corresponding base size from the input "
4549 << inputSize << (
" vs ") << inputDim << (
")");
// Fragment (the signature is not visible): walks up a chain of
// insert_strided_slice ops that feed `op`'s source.
// In each dimension it compares the extracted range against the inserted range
// [start, end), which requires equal strides. If the extract starts inside
// every inserted range and does not run past its end, `op` is rewired to
// extract directly from that insert's value to store, using offsets relative
// to the insert's offsets.
// Otherwise it moves on to the insert's destination. Handling of the disjoint
// case is on lines missing from this listing.
// NOTE(review): `patialoverlap` is a misspelling of "partialOverlap"; it is
// local, so renaming it would be safe in a follow-up.
4562 auto getElement = [](
ArrayAttr array,
int idx) {
4563 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4565 ArrayAttr extractOffsets = op.getOffsets();
4568 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4570 if (op.getSourceVectorType().getRank() !=
4571 insertOp.getSourceVectorType().getRank())
4573 ArrayAttr insertOffsets = insertOp.getOffsets();
4574 ArrayAttr insertStrides = insertOp.getStrides();
4577 if (extractOffsets.size() > insertOffsets.size())
4579 bool patialoverlap =
false;
4580 bool disjoint =
false;
4582 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4583 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
4585 int64_t start = getElement(insertOffsets, dim);
4586 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4587 int64_t offset = getElement(extractOffsets, dim);
4588 int64_t size = getElement(extractSizes, dim);
// The extract begins inside the inserted range in this dim.
4590 if (start <= offset && offset < end) {
4593 if (offset + size > end)
4594 patialoverlap =
true;
4595 offsetDiffs.push_back(offset - start);
4602 if (!disjoint && !patialoverlap) {
4603 op.setOperand(insertOp.getValueToStore());
4606 op.setOffsetsAttr(
b.getI64ArrayAttr(offsetDiffs));
4612 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
// Fragment (the signature is not visible): constant-folds an
// extract_strided_slice of a DenseElementsAttr with unit strides. It visits
// every slice position (do/while driven by incSlicePosition), linearizes the
// position into the source, and collects the corresponding element.
// NOTE(review): several source lines, including the unit-stride guard body
// and the final attribute construction, are missing from this listing.
4627 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4632 if (op.hasNonUnitStrides())
4635 VectorType sourceVecTy = op.getSourceVectorType();
4639 VectorType sliceVecTy = op.getType();
4641 int64_t rank = sliceVecTy.getRank();
4653 const auto denseValuesBegin = dense.value_begin<
Attribute>();
4655 sliceValues.reserve(sliceVecTy.getNumElements());
4659 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4661 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4662 }
while (succeeded(
incSlicePosition(currSlicePosition, sliceShape, offsets)));
4664 assert(
static_cast<int64_t>(sliceValues.size()) ==
4665 sliceVecTy.getNumElements() &&
4666 "Invalid number of slice elements");
// Folds vector.extract_strided_slice. The visible cases are identical
// source/result types (the fold result on missing line 4672 is presumably the
// source) and a splat constant source.
// NOTE(review): source lines 4672-4677 and 4679+ are missing from this
// listing.
4670OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4671 if (getSourceVectorType() == getResult().
getType())
4678 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
// Presumably fills `results` with the op's offsets as int64_t values; the body
// (source line 4686+) is not visible in this listing.
4685void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
// Canonicalization: folds extract_strided_slice(extract_strided_slice(x)),
// both with unit strides, into a single extract_strided_slice of x with
// unit strides.
// Per dimension, offsets add (a missing entry counts as 0). The size comes
// from the outer (second) op, else the inner (first) op, else x's shape.
// NOTE(review): the null check on `firstOp` (source lines 4715-4717) is
// missing from this listing.
4707class StridedSliceFolder final
4708 :
public OpRewritePattern<ExtractStridedSliceOp> {
4710 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4712 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4713 PatternRewriter &rewriter)
const override {
4714 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4718 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4721 SmallVector<int64_t> firstOffsets =
getI64SubArray(firstOp.getOffsets());
4722 SmallVector<int64_t> firstSizes =
getI64SubArray(firstOp.getSizes());
4723 SmallVector<int64_t> secondOffsets =
getI64SubArray(secondOp.getOffsets());
4724 SmallVector<int64_t> secondSizes =
getI64SubArray(secondOp.getSizes());
4726 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4727 SmallVector<int64_t> combinedOffsets(newRank, 0);
4728 SmallVector<int64_t> combinedSizes(newRank);
4729 ArrayRef<int64_t> firstSourceShape =
4730 firstOp.getSourceVectorType().getShape();
4731 for (
unsigned i = 0; i < newRank; ++i) {
4732 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4733 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4734 combinedOffsets[i] = off1 + off2;
4736 if (i < secondSizes.size()) {
4737 combinedSizes[i] = secondSizes[i];
4738 }
else if (i < firstSizes.size()) {
4739 combinedSizes[i] = firstSizes[i];
4741 combinedSizes[i] = firstSourceShape[i];
4745 SmallVector<int64_t> combinedStrides(newRank, 1);
4747 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
// Canonicalization: a unit-stride extract_strided_slice of a vector.create_mask
// becomes a create_mask of the result type. Each sliced dim size is computed
// at runtime as `maskDimSize - sliceOffset` (arith.constant + arith.subi).
// Mask dims beyond the sliced ones are kept unchanged.
// NOTE(review): lines 4773-4775, 4777-4779, 4781-4782, 4784, 4786-4787,
// 4789-4791, 4794-4796, 4799-4801, 4803, 4808-4811, 4813-4815 and 4817+ are
// missing from this listing. Any clamping of the subtracted size is not
// visible here.
4765class StridedSliceCreateMaskFolder final
4766 :
public OpRewritePattern<ExtractStridedSliceOp> {
4770 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4771 PatternRewriter &rewriter)
const override {
4772 Location loc = extractStridedSliceOp.getLoc();
4776 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4780 if (extractStridedSliceOp.hasNonUnitStrides())
4783 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4785 SmallVector<int64_t> sliceOffsets;
4788 SmallVector<int64_t> sliceSizes;
4792 SmallVector<Value> sliceMaskDimSizes;
4793 sliceMaskDimSizes.reserve(maskDimSizes.size());
4797 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4798 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4802 IntegerAttr offsetAttr =
4804 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4805 Value sliceMaskDimSize =
4806 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4807 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4812 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4816 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
// Canonicalization: a unit-stride extract_strided_slice of a
// vector.constant_mask becomes a constant_mask of the result type.
// For each sliced dim, the new size is
// max(0, min(offset + size, maskDimSize) - offset). Dims beyond the slice keep
// their mask size. If any dim ends up 0, every dim size is set to 0.
4824class StridedSliceConstantMaskFolder final
4825 :
public OpRewritePattern<ExtractStridedSliceOp> {
4829 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4830 PatternRewriter &rewriter)
const override {
4833 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4834 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4835 if (!constantMaskOp)
4838 if (extractStridedSliceOp.hasNonUnitStrides())
4841 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4843 SmallVector<int64_t> sliceOffsets;
4846 SmallVector<int64_t> sliceSizes;
4850 SmallVector<int64_t> sliceMaskDimSizes;
4851 sliceMaskDimSizes.reserve(maskDimSizes.size());
4852 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4853 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4854 int64_t sliceMaskDimSize = std::max(
4855 static_cast<int64_t
>(0),
4856 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4857 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4860 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4861 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4862 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4865 if (llvm::is_contained(sliceMaskDimSizes, 0))
4866 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4871 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
// Canonicalization for an extract_strided_slice of (per the pattern name) a
// broadcast; the match itself (source lines 4886-4890) is missing from this
// listing.
// Source dims align with the trailing result dims (offset by rankDiff). A
// slice of the broadcast source is needed when some non-unit source dim
// differs from the corresponding result dim. Unit source dims are handled
// separately on lines missing from this listing.
4879class StridedSliceBroadcast final
4880 :
public OpRewritePattern<ExtractStridedSliceOp> {
4884 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4885 PatternRewriter &rewriter)
const override {
4891 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4892 auto dstVecType = llvm::cast<VectorType>(op.getType());
4893 unsigned dstRank = dstVecType.getRank();
4894 unsigned rankDiff = dstRank - srcRank;
4898 bool needsSlice =
false;
4899 for (
unsigned i = 0; i < srcRank; i++) {
4900 if (srcVecType.getDimSize(i) != 1 &&
4901 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4908 SmallVector<int64_t> offsets =
4910 SmallVector<int64_t> sizes =
4912 for (
unsigned i = 0; i < srcRank; i++) {
4913 if (srcVecType.getDimSize(i) == 1) {
4921 source = ExtractStridedSliceOp::create(
4922 rewriter, op->getLoc(), source, offsets, sizes,
// Canonicalization: matches an extract_strided_slice whose source is a splat
// of a scalar (as reported by getScalarSplatSource). The rewrite is presumably
// a splat of the result type; it is on lines missing from this listing.
4931class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
4935 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4936 PatternRewriter &rewriter)
const override {
4938 Value splat = getScalarSplatSource(op.getSource());
// Canonicalization: rewrites a unit-stride extract_strided_slice of a
// non-scalable, non-0-d vector into a vector.extract at the leading
// `numOffsets` offsets.
// Trailing dims taken whole (slice size == source dim) are dropped from
// `numOffsets` by the loop at 4982. The loop at 5008 presumably extends
// `numOffsets` over unit-size dims (its body is missing). The trivial cases
// checked at 4990 and 4995 are presumably rejected.
// NOTE(review): many source lines are missing from this listing.
4962class ContiguousExtractStridedSliceToExtract final
4963 :
public OpRewritePattern<ExtractStridedSliceOp> {
4967 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4968 PatternRewriter &rewriter)
const override {
4969 if (op.hasNonUnitStrides())
4971 Value source = op.getOperand();
4972 auto sourceType = cast<VectorType>(source.
getType());
4973 if (sourceType.isScalable() || sourceType.getRank() == 0)
4982 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4983 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4990 if (numOffsets == 0)
4995 if (numOffsets == sourceType.getRank() &&
4996 static_cast<int>(sizes.size()) == sourceType.getRank())
5000 for (
int i = 0; i < numOffsets; ++i) {
5008 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
5009 sizes[numOffsets] == 1) {
5014 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
5015 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
// Registers the ExtractStridedSliceOp canonicalization patterns:
// StridedSliceFolder, StridedSliceCreateMaskFolder,
// StridedSliceConstantMaskFolder, StridedSliceBroadcast, StridedSliceSplat and
// ContiguousExtractStridedSliceToExtract.
5024void ExtractStridedSliceOp::getCanonicalizationPatterns(
5025 RewritePatternSet &results, MLIRContext *context) {
5028 results.
add<StridedSliceFolder, StridedSliceCreateMaskFolder,
5029 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
5030 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
// Builds a vector.transfer_read from a permutation map attribute, with no mask
// (Value()). When no padding value is supplied (the check is on a missing
// line, presumably `if (!padding)`), a ub.poison of the source element type is
// used.
5040void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5041 VectorType vectorType, Value source,
5043 AffineMapAttr permutationMapAttr,
5046 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5048 padding = ub::PoisonOp::create(builder,
result.location, elemType);
5051 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
5052 *padding, Value(), inBoundsAttr);
// Builds a vector.transfer_read from an AffineMap. A null map is replaced by a
// map derived from the source shaped type and `vectorType` (the helper call is
// on missing line 5066). When `inBounds` is absent or empty, in_bounds
// defaults to all-false, with one entry per vector dim.
5060void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5061 VectorType vectorType, Value source,
5063 AffineMap permutationMap,
5064 std::optional<ArrayRef<bool>> inBounds) {
5065 if (!permutationMap)
5067 llvm::cast<ShapedType>(source.
getType()), vectorType);
5068 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5069 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5072 SmallVector<bool>(vectorType.getRank(),
false));
5074 build(builder,
result, vectorType, source,
indices, padding,
5075 permutationMapAttr, inBoundsAttr);
// Builds a vector.transfer_read without an explicit permutation map. A null
// AffineMap is forwarded so that the AffineMap overload substitutes its
// default map.
5081void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5082 VectorType vectorType, Value source,
5084 std::optional<ArrayRef<bool>> inBounds) {
5086 build(builder,
result, vectorType, source,
indices, padding,
5087 AffineMap(), inBounds);
// Verifies that each result of `permutationMap` is either the constant 0 or a
// dim expression, and that no dim appears more than once. This is a projected
// permutation that may contain zero (broadcast) results. Diagnostics are
// emitted through `emitOpError`.
// NOTE(review): the signature and some branch lines (e.g. 5091-5093, 5097,
// 5099, 5102-5105, 5108, 5110, 5113) are missing from this listing.
5090template <
typename EmitFun>
5094 for (
auto expr : permutationMap.
getResults()) {
5095 auto dim = dyn_cast<AffineDimExpr>(expr);
5096 auto zero = dyn_cast<AffineConstantExpr>(expr);
5098 if (zero.getValue() != 0) {
5100 "requires a projected permutation_map (at most one dim or the zero "
5101 "constant can appear in each result)");
5106 return emitOpError(
"requires a projected permutation_map (at most one "
5107 "dim or the zero constant can appear in each result)");
// Reject a dim that was already used by an earlier result.
5109 if (seen[dim.getPosition()]) {
5111 "requires a permutation_map that is a permutation (found one dim "
5112 "used more than once)");
5114 seen[dim.getPosition()] =
true;
// Fragment of the shared transfer-op verifier; its name (source line 5120) and
// many other lines are missing from this listing. Visible checks:
//  * the removed `masked` attribute is rejected,
//  * the source must be a memref or a ranked tensor,
//  * with a vector element type: the result minor size must be a multiple of
//    the element's minor size, the element rank must not exceed the result
//    rank, and masks are not supported,
//  * with a scalar element type: a bitwidth-multiple check (partly missing),
//  * the permutation map has no symbols, and its input count equals the
//    source rank,
//  * an explicit mask type must equal the inferred mask type,
//  * in_bounds has one entry per permutation_map result.
// NOTE(review): the element-rank check is `sourceVecEltRank > resultVecRank`,
// but its message says the ranks must "match". Equal-or-smaller element ranks
// are accepted, so the wording overstates the requirement.
5121 VectorType vectorType, VectorType maskType,
5122 VectorType inferredMaskType,
AffineMap permutationMap,
5124 if (op->hasAttr(
"masked")) {
5125 return op->emitOpError(
"masked attribute has been removed. "
5126 "Use in_bounds instead.");
5129 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
5130 return op->emitOpError(
5131 "requires source to be a memref or ranked tensor type");
5133 auto elementType = shapedType.getElementType();
5135 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
5137 unsigned sourceVecSize =
5139 vectorElementType.getShape().back();
5140 unsigned resultVecSize =
5142 vectorType.getShape().back();
5143 if (resultVecSize % sourceVecSize != 0)
5144 return op->emitOpError(
5145 "requires the bitwidth of the minor 1-D vector to be an integral "
5146 "multiple of the bitwidth of the minor 1-D vector of the source");
5148 unsigned sourceVecEltRank = vectorElementType.getRank();
5149 unsigned resultVecRank = vectorType.getRank();
5150 if (sourceVecEltRank > resultVecRank)
5151 return op->emitOpError(
5152 "requires source vector element and vector result ranks to match.");
5153 unsigned rankOffset = resultVecRank - sourceVecEltRank;
5156 return op->emitOpError(
"requires a permutation_map with result dims of "
5157 "the same rank as the vector type");
5160 return op->emitOpError(
"does not support masks with vector element type");
// Scalar element type: a 0-d result vector counts as minor size 1.
5163 unsigned minorSize =
5164 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
5165 unsigned resultVecSize =
5168 return op->emitOpError(
5169 "requires the bitwidth of the minor 1-D vector to be an integral "
5170 "multiple of the bitwidth of the source element type");
5174 return op->emitOpError(
"requires a permutation_map with result dims of "
5175 "the same rank as the vector type");
5179 return op->emitOpError(
"requires permutation_map without symbols");
5181 if (permutationMap.
getNumInputs() != shapedType.getRank())
5182 return op->emitOpError(
"requires a permutation_map with input dims of the "
5183 "same rank as the source type");
5185 if (maskType && maskType != inferredMaskType)
5186 return op->emitOpError(
"inferred mask type (")
5187 << inferredMaskType <<
") and mask operand type (" << maskType
5191 return op->emitOpError(
"expects the in_bounds attr of same rank "
5192 "as permutation_map results: ")
5193 << AffineMapAttr::get(permutationMap)
5194 <<
" vs inBounds of size: " << inBounds.size();
// Fragment (the signature is not visible): collects the attributes to elide
// when printing a transfer op. The operand segment sizes are always elided,
// the permutation map when it is a minor identity, and in_bounds when every
// entry is false.
5201 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5202 if (op.getPermutationMap().isMinorIdentity())
5203 elidedAttrs.push_back(op.getPermutationMapAttrName());
5205 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
5206 elidedAttrs.push_back(op.getInBoundsAttrName());
5210void TransferReadOp::print(OpAsmPrinter &p) {
5213 p <<
", " << getMask();
5220 auto i1Type = IntegerType::get(permMap.
getContext(), 1);
5222 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
5227 if (maskShape.empty())
5228 maskShape.push_back(1);
5233 return VectorType::get(maskShape, i1Type, scalableDims);
5250 if (hasMask.succeeded()) {
5257 if (types.size() != 2)
5258 return parser.
emitError(typesLoc,
"requires two types");
5260 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5261 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5262 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5263 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5265 return parser.
emitError(typesLoc,
"requires vector type");
5266 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(
result.name);
5270 if (shapedType.getRank() <
5273 "expected a custom permutation_map when "
5274 "rank(source) != rank(destination)");
5276 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5278 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5280 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(
result.name);
5281 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5282 if (!inBoundsAttr) {
5283 result.addAttribute(inBoundsAttrName,
5292 if (hasMask.succeeded()) {
5293 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5295 maskInfo.
location,
"does not support masks with vector element type");
5298 "expected the same rank for the vector and the "
5299 "results of the permutation map");
5307 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5309 {1, static_cast<int32_t>(indexInfo.size()), 1,
5310 static_cast<int32_t>(hasMask.succeeded())}));
5314LogicalResult TransferReadOp::verify() {
5316 ShapedType shapedType = getShapedType();
5318 VectorType maskType = getMaskType();
5319 auto paddingType = getPadding().getType();
5320 auto permutationMap = getPermutationMap();
5321 VectorType inferredMaskType =
5324 auto sourceElementType = shapedType.getElementType();
5326 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
5327 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5330 shapedType, vectorType, maskType,
5331 inferredMaskType, permutationMap, getInBounds())))
5334 if (
auto sourceVectorElementType =
5335 llvm::dyn_cast<VectorType>(sourceElementType)) {
5338 if (sourceVectorElementType != paddingType)
5340 "requires source element type and padding type to match.");
5344 if (!VectorType::isValidElementType(paddingType))
5345 return emitOpError(
"requires valid padding vector elemental type");
5348 if (paddingType != sourceElementType)
5350 "requires formal padding and source of the same elemental type");
5361Type TransferReadOp::getExpectedMaskType() {
5368VectorType TransferReadOp::getVectorType() {
5369 return cast<VectorType>(getVector().
getType());
5372template <
typename TransferOp>
5376 if (op.getShapedType().isDynamicDim(indicesIdx))
5380 if (!cstOp.has_value())
5383 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5384 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5386 return cstOp.value() + vectorSize <= sourceSize;
5389template <
typename TransferOp>
5393 if (op.getTransferRank() == 0)
5396 bool changed =
false;
5398 newInBounds.reserve(op.getTransferRank());
5403 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
5405 if (op.isDimInBounds(i)) {
5406 newInBounds.push_back(
true);
5411 bool inBounds =
false;
5412 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
5415 dimExpr.getPosition());
5416 nonBcastDims.push_back(i);
5419 newInBounds.push_back(inBounds);
5421 changed |= inBounds;
5427 bool allNonBcastDimsInBounds = llvm::all_of(
5428 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
5429 if (allNonBcastDimsInBounds) {
5431 changed |= !newInBounds[idx];
5432 newInBounds[idx] =
true;
5440 op.setInBoundsAttr(
b.getBoolArrayAttr(newInBounds));
5444template <
typename TransferOp>
5446 auto mask = op.getMask();
5453 op.getMaskMutable().clear();
5461template <
typename TransferOp>
5463 VectorType vecType = op.getVectorType();
5464 if (vecType.getRank() != 1 || vecType.getShape()[0] != 1 ||
5465 vecType.isScalable())
5472 int64_t srcRank = op.getShapedType().getRank();
5478 op.setPermutationMapAttr(AffineMapAttr::get(minorIdentity));
5492static Value foldRAW(TransferReadOp readOp) {
5493 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5495 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5498 return defWrite.getVector();
5500 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5501 cast<VectorTransferOpInterface>(readOp.getOperation())))
5503 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5508OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5509 if (Value vec = foldRAW(*
this))
5522 return OpFoldResult();
5525std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5529void TransferReadOp::getEffects(
5530 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5532 if (llvm::isa<MemRefType>(getShapedType()))
5533 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5534 SideEffects::DefaultResource::get());
5538 if (hasPureTensorSemantics())
5545static AffineMap inverseWithUnusedDims(AffineMap map) {
5547 "expected a projected permutation map");
5552 int64_t pos = cast<AffineDimExpr>(
result).getPosition();
5582struct TransferReadAfterWriteToBroadcast
5583 :
public OpRewritePattern<TransferReadOp> {
5586 LogicalResult matchAndRewrite(TransferReadOp readOp,
5587 PatternRewriter &rewriter)
const override {
5588 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5592 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5596 if (readOp.getMask() || defWrite.getMask())
5599 if (readOp.getIndices() != defWrite.getIndices())
5602 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5606 if (readOp.getTransferChunkAccessed() !=
5607 defWrite.getTransferChunkAccessed())
5614 AffineMap readMap = readOp.getPermutationMap();
5615 AffineMap writeMap = defWrite.getPermutationMap();
5616 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5617 AffineMap composedMap = readMap.
compose(invWriteMap);
5631 int64_t numBroadcastedDims = broadcastedDims.size();
5632 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5634 for (
auto [idx, expr] : llvm::enumerate(composedMap.
getResults())) {
5635 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
5636 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5637 invPerm[effectiveDim] = idx;
5642 VectorType readVecTy = readOp.getVectorType();
5644 auto broadcastedVecTy =
5646 readVecTy.getElementType(),
5649 Value vec = defWrite.getVector();
5650 Location loc = readOp.getLoc();
5651 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5658void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5659 MLIRContext *context) {
5660 results.
add<TransferReadAfterWriteToBroadcast>(context);
5663FailureOr<std::optional<SmallVector<Value>>>
5664TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5665 if (!hasPureBufferSemantics())
5676void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5678 AffineMapAttr permutationMapAttr,
5681 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
5682 build(builder,
result, resultType, vector, dest,
indices, permutationMapAttr,
5683 mask, inBoundsAttr);
5687void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5689 AffineMapAttr permutationMapAttr,
5691 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5692 Value(), inBoundsAttr);
5697void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5699 AffineMap permutationMap,
5700 std::optional<ArrayRef<bool>> inBounds) {
5701 if (!permutationMap)
5704 llvm::cast<VectorType>(vector.
getType()));
5705 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5707 (inBounds && !inBounds.value().empty())
5710 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
5711 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5712 Value(), inBoundsAttr);
5717void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5719 std::optional<ArrayRef<bool>> inBounds) {
5724ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5725 OperationState &
result) {
5728 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5729 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5730 SmallVector<Type, 2> types;
5731 OpAsmParser::UnresolvedOperand maskInfo;
5737 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
5742 if (types.size() != 2)
5743 return parser.
emitError(typesLoc,
"requires two types");
5745 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5747 return parser.
emitError(typesLoc,
"requires vector type");
5748 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5749 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5750 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5751 auto permMapAttrName =
5752 TransferWriteOp::getPermutationMapAttrName(
result.name);
5753 auto permMapAttr =
result.attributes.get(permMapAttrName);
5756 if (shapedType.getRank() <
5759 "expected a custom permutation_map when "
5760 "rank(source) != rank(destination)");
5762 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5764 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5766 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(
result.name);
5767 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5768 if (!inBoundsAttr) {
5769 result.addAttribute(inBoundsAttrName,
5777 if (hasMask.succeeded()) {
5778 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5780 maskInfo.
location,
"does not support masks with vector element type");
5783 "expected the same rank for the vector and the "
5784 "results of the permutation map");
5790 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5792 {1, 1, static_cast<int32_t>(indexInfo.size()),
5793 static_cast<int32_t>(hasMask.succeeded())}));
5794 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5798void TransferWriteOp::print(OpAsmPrinter &p) {
5801 p <<
", " << getMask();
5806LogicalResult TransferWriteOp::verify() {
5808 ShapedType shapedType = getShapedType();
5810 VectorType maskType = getMaskType();
5811 auto permutationMap = getPermutationMap();
5812 VectorType inferredMaskType =
5816 if (llvm::size(
getIndices()) != shapedType.getRank())
5817 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5821 if (hasBroadcastDim())
5822 return emitOpError(
"should not have broadcast dimensions");
5825 shapedType, vectorType, maskType,
5826 inferredMaskType, permutationMap, getInBounds())))
5839Type TransferWriteOp::getExpectedMaskType() {
5846Value TransferWriteOp::getVector() {
return getOperand(0); }
5847VectorType TransferWriteOp::getVectorType() {
5848 return cast<VectorType>(getValueToStore().
getType());
5871static LogicalResult foldReadInitWrite(TransferWriteOp write,
5872 ArrayRef<Attribute>,
5873 SmallVectorImpl<OpFoldResult> &results) {
5875 if (write.getTransferRank() == 0)
5877 auto rankedTensorType =
5878 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5880 if (!rankedTensorType)
5883 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5887 if (read.getTransferRank() == 0)
5890 if (!read.getPermutationMap().isMinorIdentity() ||
5891 !write.getPermutationMap().isMinorIdentity())
5894 if (read.getTransferRank() != write.getTransferRank())
5897 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5900 if (read.getMask() || write.getMask())
5903 if (read.getBase().getType() != rankedTensorType)
5906 if (read.getVectorType() != write.getVectorType())
5909 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5912 auto isNotConstantZero = [](Value v) {
5914 return !cstOp.has_value() || cstOp.value() != 0;
5916 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5917 llvm::any_of(write.getIndices(), isNotConstantZero))
5920 results.push_back(read.getBase());
5924static bool checkSameValueWAR(vector::TransferReadOp read,
5925 vector::TransferWriteOp write) {
5926 return read.getBase() == write.getBase() &&
5927 read.getIndices() == write.getIndices() &&
5928 read.getPermutationMap() == write.getPermutationMap() &&
5929 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5946static LogicalResult foldWAR(TransferWriteOp write,
5947 SmallVectorImpl<OpFoldResult> &results) {
5948 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5950 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5954 if (!checkSameValueWAR(read, write))
5956 results.push_back(read.getBase());
5960LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5961 SmallVectorImpl<OpFoldResult> &results) {
5962 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
5964 if (succeeded(foldWAR(*
this, results)))
5978std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5982void TransferWriteOp::getEffects(
5983 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5985 if (llvm::isa<MemRefType>(getShapedType()))
5986 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5987 SideEffects::DefaultResource::get());
5991 if (hasPureTensorSemantics())
6021class FoldWaw final :
public OpRewritePattern<TransferWriteOp> {
6024 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
6025 PatternRewriter &rewriter)
const override {
6026 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
6028 vector::TransferWriteOp writeToModify = writeOp;
6030 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
6034 writeToModify.getBaseMutable().assign(defWrite.getBase());
6039 cast<VectorTransferOpInterface>(defWrite.getOperation()),
6040 cast<VectorTransferOpInterface>(writeOp.getOperation())))
6044 if (!defWrite->hasOneUse())
6046 writeToModify = defWrite;
6047 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
6076struct SwapExtractSliceOfTransferWrite
6077 :
public OpRewritePattern<tensor::InsertSliceOp> {
6081 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
6082 PatternRewriter &rewriter)
const override {
6083 if (!insertOp.hasUnitStride())
6086 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
6087 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
6089 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
6090 if (!transferOp || !transferOp->hasOneUse())
6095 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
6097 "use-def chain is rank-reducing");
6101 if (!extractOp.hasZeroOffset()) {
6103 "ExtractSliceOp has non-zero offset");
6107 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
6108 return getConstantIntValue(value) == static_cast<int64_t>(0);
6111 "TranferWriteOp has non-zero offset");
6115 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
6117 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
6120 for (
auto [insertSize, extractSize] :
6121 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
6124 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
6129 assert(transferOp.getVectorType().hasStaticShape() &&
6130 "expected vector to have a static shape");
6131 ArrayRef<int64_t>
vectorShape = transferOp.getVectorType().getShape();
6133 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
6134 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
6136 insertOp,
"TransferWriteOp may not write the full tensor.");
6141 SmallVector<bool> newInBounds(
vectorShape.size(),
false);
6142 auto newExtractOp = tensor::ExtractSliceOp::create(
6143 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
6144 insertOp.getDest(), insertOp.getMixedOffsets(),
6145 insertOp.getMixedSizes(), insertOp.getMixedStrides());
6146 auto newTransferWriteOp = TransferWriteOp::create(
6147 rewriter, transferOp.getLoc(), transferOp.getVector(),
6148 newExtractOp.getResult(), transferOp.getIndices(),
6149 transferOp.getPermutationMapAttr(),
6152 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
6160void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
6161 MLIRContext *context) {
6162 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
6165FailureOr<std::optional<SmallVector<Value>>>
6166TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
6167 if (!hasPureBufferSemantics())
6177static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
6179 MemRefType memRefTy) {
6182 if (!vecTy.isScalable() &&
6183 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
6186 if (!memRefTy.isLastDimUnitStride())
6187 return op->
emitOpError(
"most minor memref dim must have unit stride");
6191LogicalResult vector::LoadOp::verify() {
6195 if (
failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
6200 return emitOpError(
"memref strides must be non-negative");
6202 if (memRefTy.getRank() < resVecTy.getRank())
6204 "destination memref has lower rank than the result vector");
6207 Type memElemTy = memRefTy.getElementType();
6208 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6209 if (memVecTy != resVecTy)
6210 return emitOpError(
"base memref and result vector types should match");
6211 memElemTy = memVecTy.getElementType();
6214 if (resVecTy.getElementType() != memElemTy)
6215 return emitOpError(
"base and result element types should match");
6216 if (llvm::size(
getIndices()) != memRefTy.getRank())
6217 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6221OpFoldResult LoadOp::fold(FoldAdaptor) {
6224 return OpFoldResult();
6227std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6231FailureOr<std::optional<SmallVector<Value>>>
6232LoadOp::bubbleDownCasts(OpBuilder &builder) {
6241LogicalResult vector::StoreOp::verify() {
6245 if (
failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
6250 return emitOpError(
"memref strides must be non-negative");
6252 if (memRefTy.getRank() < valueVecTy.getRank())
6253 return emitOpError(
"source memref has lower rank than the vector to store");
6256 Type memElemTy = memRefTy.getElementType();
6257 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6258 if (memVecTy != valueVecTy)
6260 "base memref and valueToStore vector types should match");
6261 memElemTy = memVecTy.getElementType();
6264 if (valueVecTy.getElementType() != memElemTy)
6265 return emitOpError(
"base and valueToStore element type should match");
6266 if (llvm::size(
getIndices()) != memRefTy.getRank())
6267 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6271LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6272 SmallVectorImpl<OpFoldResult> &results) {
6276std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6280FailureOr<std::optional<SmallVector<Value>>>
6281StoreOp::bubbleDownCasts(OpBuilder &builder) {
6290LogicalResult MaskedLoadOp::verify() {
6291 VectorType maskVType = getMaskVectorType();
6292 VectorType passVType = getPassThruVectorType();
6299 if (llvm::size(
getIndices()) != memType.getRank())
6300 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6301 if (resVType.getShape() != maskVType.getShape())
6302 return emitOpError(
"expected result shape to match mask shape");
6303 if (resVType != passVType)
6304 return emitOpError(
"expected pass_thru of same type as result type");
6309class MaskedLoadFolder final :
public OpRewritePattern<MaskedLoadOp> {
6312 LogicalResult matchAndRewrite(MaskedLoadOp
load,
6313 PatternRewriter &rewriter)
const override {
6325 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
6330void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6331 MLIRContext *context) {
6332 results.
add<MaskedLoadFolder>(context);
6335OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6338 return OpFoldResult();
6341FailureOr<std::optional<SmallVector<Value>>>
6342MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6351LogicalResult MaskedStoreOp::verify() {
6352 VectorType maskVType = getMaskVectorType();
6359 if (llvm::size(
getIndices()) != memType.getRank())
6360 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6361 if (valueVType.getShape() != maskVType.getShape())
6362 return emitOpError(
"expected valueToStore shape to match mask shape");
6367class MaskedStoreFolder final :
public OpRewritePattern<MaskedStoreOp> {
6370 LogicalResult matchAndRewrite(MaskedStoreOp store,
6371 PatternRewriter &rewriter)
const override {
6375 store, store.getValueToStore(), store.getBase(), store.getIndices());
6383 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
6388void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6389 MLIRContext *context) {
6390 results.
add<MaskedStoreFolder>(context);
6393LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6394 SmallVectorImpl<OpFoldResult> &results) {
6398FailureOr<std::optional<SmallVector<Value>>>
6399MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6408LogicalResult GatherOp::verify() {
6409 VectorType indVType = getIndexVectorType();
6410 VectorType maskVType = getMaskVectorType();
6412 ShapedType baseType = getBaseType();
6414 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6415 return emitOpError(
"requires base to be a memref or ranked tensor type");
6420 if (llvm::size(getOffsets()) != baseType.getRank())
6421 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6422 if (resVType.getShape() != indVType.getShape())
6423 return emitOpError(
"expected result dim to match indices dim");
6424 if (resVType.getShape() != maskVType.getShape())
6425 return emitOpError(
"expected result dim to match mask dim");
6426 if (resVType != getPassThruVectorType())
6427 return emitOpError(
"expected pass_thru of same type as result type");
6428 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6430 "alignment is only supported for memref bases, not tensor bases");
6439Type GatherOp::getExpectedMaskType() {
6440 auto vecType = this->getIndexVectorType();
6441 return VectorType::get(vecType.getShape(),
6442 IntegerType::get(vecType.getContext(), 1),
6443 vecType.getScalableDims());
6446std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6451static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6452 auto vecType = dyn_cast<VectorType>(indexVec.
getType());
6453 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6459 DenseIntElementsAttr elements;
6464 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6468class GatherFolder final :
public OpRewritePattern<GatherOp> {
6471 LogicalResult matchAndRewrite(GatherOp gather,
6472 PatternRewriter &rewriter)
const override {
6477 rewriter.
replaceOp(gather, gather.getPassThru());
6482 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
6488class FoldContiguousGather final :
public OpRewritePattern<GatherOp> {
6491 LogicalResult matchAndRewrite(GatherOp op,
6492 PatternRewriter &rewriter)
const override {
6493 if (!isa<MemRefType>(op.getBase().getType()))
6496 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6500 op.getOffsets(), op.getMask(),
6507void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6508 MLIRContext *context) {
6509 results.
add<GatherFolder, FoldContiguousGather>(context);
6512FailureOr<std::optional<SmallVector<Value>>>
6513GatherOp::bubbleDownCasts(OpBuilder &builder) {
6522LogicalResult ScatterOp::verify() {
6523 VectorType indVType = getIndexVectorType();
6524 VectorType maskVType = getMaskVectorType();
6526 ShapedType baseType = getBaseType();
6528 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6529 return emitOpError(
"requires base to be a memref or ranked tensor type");
6534 if (llvm::size(getOffsets()) != baseType.getRank())
6535 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6536 if (valueVType.getShape() != indVType.getShape())
6537 return emitOpError(
"expected valueToStore dim to match indices dim");
6538 if (valueVType.getShape() != maskVType.getShape())
6539 return emitOpError(
"expected valueToStore dim to match mask dim");
6540 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6542 "alignment is only supported for memref bases, not tensor bases");
6547class ScatterFolder final :
public OpRewritePattern<ScatterOp> {
6550 LogicalResult matchAndRewrite(ScatterOp scatter,
6551 PatternRewriter &rewriter)
const override {
6552 ShapedType baseType = scatter.getBaseType();
6553 bool isMemRef = isa<MemRefType>(baseType);
6554 if (!isMemRef && !isa<RankedTensorType>(baseType))
6567 rewriter.
replaceOp(scatter, scatter.getBase());
6572 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
6578class FoldContiguousScatter final :
public OpRewritePattern<ScatterOp> {
6581 LogicalResult matchAndRewrite(ScatterOp op,
6582 PatternRewriter &rewriter)
const override {
6585 if (!isa<MemRefType>(op.getBase().getType()))
6588 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6592 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6598void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6599 MLIRContext *context) {
6600 results.
add<ScatterFolder, FoldContiguousScatter>(context);
6603FailureOr<std::optional<SmallVector<Value>>>
6604ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6613LogicalResult ExpandLoadOp::verify() {
6614 VectorType maskVType = getMaskVectorType();
6615 VectorType passVType = getPassThruVectorType();
6622 if (llvm::size(
getIndices()) != memType.getRank())
6623 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6624 if (resVType.getShape() != maskVType.getShape())
6625 return emitOpError(
"expected result shape to match mask shape");
6626 if (resVType != passVType)
6627 return emitOpError(
"expected pass_thru of same type as result type");
6632class ExpandLoadFolder final :
public OpRewritePattern<ExpandLoadOp> {
6635 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6636 PatternRewriter &rewriter)
const override {
6640 expand, expand.getType(), expand.getBase(), expand.getIndices());
6643 rewriter.
replaceOp(expand, expand.getPassThru());
6648 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
6653void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6654 MLIRContext *context) {
6655 results.
add<ExpandLoadFolder>(context);
6658FailureOr<std::optional<SmallVector<Value>>>
6659ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6668LogicalResult CompressStoreOp::verify() {
6669 VectorType maskVType = getMaskVectorType();
6676 if (llvm::size(
getIndices()) != memType.getRank())
6677 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6678 if (valueVType.getShape() != maskVType.getShape())
6679 return emitOpError(
"expected valueToStore shape to match mask shape");
6684class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
6687 LogicalResult matchAndRewrite(CompressStoreOp compress,
6688 PatternRewriter &rewriter)
const override {
6692 compress, compress.getValueToStore(), compress.getBase(),
6693 compress.getIndices());
6701 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
6706void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6707 MLIRContext *context) {
6708 results.
add<CompressStoreFolder>(context);
6711FailureOr<std::optional<SmallVector<Value>>>
6712CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6721void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6723 setResultRanges(getResult(), argRanges.front());
6726std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6727 return llvm::to_vector<4>(getResultVectorType().
getShape());
6730LogicalResult ShapeCastOp::verify() {
6732 VectorType sourceType = getSourceVectorType();
6733 VectorType resultType = getResultVectorType();
6741 int64_t sourceNElms = sourceType.getNumElements();
6742 int64_t resultNElms = resultType.getNumElements();
6743 if (sourceNElms != resultNElms) {
6744 return emitOpError() <<
"has different number of elements at source ("
6745 << sourceNElms <<
") and result (" << resultNElms
6750 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6751 int64_t resultNScalableDims = resultType.getNumScalableDims();
6752 if (sourceNScalableDims != resultNScalableDims)
6753 return emitOpError() <<
"has different number of scalable dims at source ("
6754 << sourceNScalableDims <<
") and result ("
6755 << resultNScalableDims <<
")";
6764static bool isOrderPreserving(TransposeOp transpose) {
6765 ArrayRef<int64_t> permutation = transpose.getPermutation();
6766 VectorType sourceType = transpose.getSourceVectorType();
6767 ArrayRef<int64_t> inShape = sourceType.getShape();
6768 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6769 auto isNonScalableUnitDim = [&](int64_t dim) {
6770 return inShape[dim] == 1 && !inDimIsScalable[dim];
6772 int64_t current = 0;
6773 for (
auto p : permutation) {
6774 if (!isNonScalableUnitDim(p)) {
6784OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6786 VectorType resultType =
getType();
6789 if (getSource().
getType() == resultType)
6793 if (
auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6794 setOperand(precedingShapeCast.getSource());
6799 if (
auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6800 if (isOrderPreserving(transpose)) {
6801 setOperand(transpose.getVector());
6809 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6810 if (bcastOp.getSourceType() == resultType)
6811 return bcastOp.getSource();
6815 if (
auto denseAttr =
6816 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6817 return denseAttr.reshape(
getType());
6833static VectorType trimTrailingOneDims(VectorType oldType) {
6834 ArrayRef<int64_t> oldShape = oldType.getShape();
6835 ArrayRef<int64_t> newShape = oldShape;
6837 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6838 ArrayRef<bool> newScalableDims = oldScalableDims;
6840 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6841 newShape = newShape.drop_back(1);
6842 newScalableDims = newScalableDims.drop_back(1);
6847 if (newShape.empty()) {
6848 newShape = oldShape.take_back();
6849 newScalableDims = oldScalableDims.take_back();
6852 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6867class ShapeCastCreateMaskFolderTrailingOneDim final
6868 :
public OpRewritePattern<ShapeCastOp> {
6872 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6873 PatternRewriter &rewriter)
const override {
6874 Value shapeOpSrc = shapeOp->getOperand(0);
6875 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
6876 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
6877 if (!createMaskOp && !constantMaskOp)
6880 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6881 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6883 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6884 if (newVecType != shapeOpResTy)
6887 auto numDimsToDrop =
6888 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6895 auto maskOperands = createMaskOp.getOperands();
6896 auto numMaskOperands = maskOperands.size();
6899 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6901 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6902 if (!constant || (constant.value() != 1))
6905 SmallVector<Value> newMaskOperands =
6906 maskOperands.drop_back(numDimsToDrop);
6913 if (constantMaskOp) {
6914 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6915 auto numMaskOperands = maskDimSizes.size();
6918 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6920 if (maskDimSizes[i] != 1)
6924 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6937int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6938 ArrayRef<int64_t> dstShape) {
6939 int stretchingFactor = 1;
6940 int numLeadingDims = dstShape.size() - srcShape.size();
6941 for (
int i = 0, e = srcShape.size(); i < e; i++) {
6942 int64_t dstDim = dstShape[numLeadingDims + i];
6943 if (srcShape[i] == 1 && dstDim != 1) {
6944 stretchingFactor *= dstDim;
6947 return stretchingFactor;
6951class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
6955 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6956 PatternRewriter &rewriter)
const override {
6958 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6962 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6963 bool srcIsScalar = !srcVectorType;
6971 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6972 ArrayRef<int64_t> dstShape = dstVectorType.getShape();
6973 ArrayRef<int64_t> srcShape =
6974 srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
6975 ArrayRef<int64_t> broadcastShape =
6976 broadcastOp.getResultVectorType().getShape();
6980 BroadcastableToResult::Success) {
6988 if (srcVectorType.getNumElements() != 1) {
6989 if (getBroadcastStretchingFactor(srcShape, dstShape) !=
6990 getBroadcastStretchingFactor(srcShape, broadcastShape)) {
6997 broadcastOp.getSource());
7016class FoldShapeCastOfFromElements final :
public OpRewritePattern<ShapeCastOp> {
7020 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
7021 PatternRewriter &rewriter)
const override {
7022 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
7027 shapeCastOp, shapeCastOp.getResultVectorType(),
7028 fromElements.getElements());
7035void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
7036 MLIRContext *context) {
7037 results.
add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
7038 FoldShapeCastOfFromElements>(context);
7045LogicalResult BitCastOp::verify() {
7046 auto sourceVectorType = getSourceVectorType();
7047 auto resultVectorType = getResultVectorType();
7049 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
7050 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
7051 return emitOpError(
"dimension size mismatch at: ") << i;
7054 DataLayout dataLayout = DataLayout::closest(*
this);
7055 auto sourceElementBits =
7057 auto resultElementBits =
7060 if (sourceVectorType.getRank() == 0) {
7061 if (sourceElementBits != resultElementBits)
7062 return emitOpError(
"source/result bitwidth of the 0-D vector element "
7063 "types must be equal");
7064 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
7065 resultElementBits * resultVectorType.getShape().back()) {
7067 "source/result bitwidth of the minor 1-D vectors must be equal");
7073OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
7079 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
7080 if (getResult().
getType() == otherOp.getSource().getType())
7081 return otherOp.getSource();
7083 setOperand(otherOp.getSource());
7087 Attribute sourceConstant = adaptor.getSource();
7088 if (!sourceConstant)
7091 Type srcElemType = getSourceVectorType().getElementType();
7092 Type dstElemType = getResultVectorType().getElementType();
7094 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
7095 if (floatPack.isSplat()) {
7096 auto splat = floatPack.getSplatValue<FloatAttr>();
7099 if (srcElemType.
isF16() && dstElemType.
isF32()) {
7100 uint32_t bits =
static_cast<uint32_t
>(
7101 splat.getValue().bitcastToAPInt().getZExtValue());
7103 bits = (bits << 16) | (bits & 0xffff);
7104 APInt intBits(32, bits);
7105 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
7111 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
7112 if (intPack.isSplat()) {
7113 auto splat = intPack.getSplatValue<IntegerAttr>();
7115 if (llvm::isa<IntegerType>(dstElemType) && srcElemType.
isIntOrFloat()) {
7120 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
7121 APInt intBits = splat.getValue().zext(dstBitWidth);
7124 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
7125 intBits = (intBits << srcBitWidth) | intBits;
7135std::optional<SmallVector<int64_t, 4>> BitCastOp::getShapeForUnroll() {
7136 return llvm::to_vector<4>(getResultVectorType().
getShape());
7143static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
7144 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
7145 SmallVector<int64_t, 8> res(memRefType.getShape());
7147 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
7153void TypeCastOp::build(OpBuilder &builder, OperationState &
result,
7155 result.addOperands(source);
7156 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
7157 VectorType vectorType =
7158 VectorType::get(extractShape(memRefType),
7160 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
7161 memRefType.getMemorySpace()));
7164LogicalResult TypeCastOp::verify() {
7165 MemRefType canonicalType =
getMemRefType().canonicalizeStridedLayout();
7166 if (!canonicalType.getLayout().isIdentity())
7167 return emitOpError(
"expects operand to be a memref with identity layout");
7168 if (!getResultMemRefType().getLayout().isIdentity())
7169 return emitOpError(
"expects result to be a memref with identity layout");
7170 if (getResultMemRefType().getMemorySpace() !=
7172 return emitOpError(
"expects result in same memory space");
7175 auto resultType = getResultMemRefType();
7179 "expects result and operand with same underlying scalar type: ")
7181 if (extractShape(sourceType) != extractShape(resultType))
7183 "expects concatenated result and operand shapes to be equal: ")
7192void vector::TransposeOp::build(OpBuilder &builder, OperationState &
result,
7193 Value vector, ArrayRef<int64_t> permutation) {
7194 VectorType vt = llvm::cast<VectorType>(vector.
getType());
7195 SmallVector<int64_t, 4> transposedShape(vt.getRank());
7196 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
7197 for (
unsigned i = 0; i < permutation.size(); ++i) {
7198 transposedShape[i] = vt.getShape()[permutation[i]];
7199 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
7202 result.addOperands(vector);
7203 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
7204 transposedScalableDims));
7205 result.addAttribute(TransposeOp::getPermutationAttrName(
result.name),
7209OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
7212 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
7213 return splat.reshape(getResultVectorType());
7230 if (getSourceVectorType() == getResultVectorType() &&
7231 isOrderPreserving(*
this))
7237LogicalResult vector::TransposeOp::verify() {
7238 VectorType vectorType = getSourceVectorType();
7239 VectorType resultType = getResultVectorType();
7240 int64_t rank = resultType.getRank();
7241 if (vectorType.getRank() != rank)
7242 return emitOpError(
"vector result rank mismatch: ") << rank;
7244 ArrayRef<int64_t> perm = getPermutation();
7245 int64_t size = perm.size();
7247 return emitOpError(
"transposition length mismatch: ") << size;
7248 SmallVector<bool, 8> seen(rank,
false);
7249 for (
const auto &ta : llvm::enumerate(perm)) {
7250 if (ta.value() < 0 || ta.value() >= rank)
7251 return emitOpError(
"transposition index out of range: ") << ta.value();
7252 if (seen[ta.value()])
7253 return emitOpError(
"duplicate position index: ") << ta.value();
7254 seen[ta.value()] =
true;
7255 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7256 return emitOpError(
"dimension size mismatch at: ") << ta.value();
7261std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7262 return llvm::to_vector<4>(getResultVectorType().
getShape());
7265void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7267 setResultRanges(getResult(), argRanges.front());
7273class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
7277 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7278 PatternRewriter &rewriter)
const override {
7280 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7281 ArrayRef<int64_t> permutation2) {
7282 SmallVector<int64_t, 4>
result;
7283 for (
auto index : permutation2)
7284 result.push_back(permutation1[index]);
7289 vector::TransposeOp parentTransposeOp =
7290 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7291 if (!parentTransposeOp)
7294 SmallVector<int64_t, 4> permutation = composePermutations(
7295 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7298 transposeOp, transposeOp.getResult().
getType(),
7299 parentTransposeOp.getVector(), permutation);
7305class FoldTransposeSplat final :
public OpRewritePattern<TransposeOp> {
7309 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7310 PatternRewriter &rewriter)
const override {
7311 Value splat = getScalarSplatSource(transposeOp.getVector());
7316 transposeOp, transposeOp.getResultVectorType(), splat);
7322class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
7326 LogicalResult matchAndRewrite(TransposeOp transpOp,
7327 PatternRewriter &rewriter)
const override {
7328 Value transposeSrc = transpOp.getVector();
7329 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
7330 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
7331 if (!createMaskOp && !constantMaskOp)
7336 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7339 auto maskOperands = createMaskOp.getOperands();
7340 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7344 transpOp, transpOp.getResultVectorType(), newOperands);
7349 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7353 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7359class FoldTransposeShapeCast final :
public OpRewritePattern<TransposeOp> {
7363 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7364 PatternRewriter &rewriter)
const override {
7366 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7369 if (!isOrderPreserving(transposeOp))
7372 VectorType resultType = transposeOp.getType();
7379 shapeCastOp.getSource());
7398class FoldTransposeFromElements final :
public OpRewritePattern<TransposeOp> {
7401 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7402 PatternRewriter &rewriter)
const override {
7403 auto fromElementsOp =
7404 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7405 if (!fromElementsOp)
7408 VectorType srcTy = fromElementsOp.getDest().getType();
7409 VectorType dstTy = transposeOp.getType();
7411 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7412 int64_t rank = srcTy.getRank();
7415 SmallVector<int64_t> inversePerm(rank, 0);
7416 for (int64_t i = 0; i < rank; ++i)
7417 inversePerm[permutation[i]] = i;
7419 ArrayRef<int64_t> srcShape = srcTy.getShape();
7420 ArrayRef<int64_t> dstShape = dstTy.getShape();
7421 SmallVector<int64_t> srcIdx(rank, 0);
7422 SmallVector<int64_t> dstIdx(rank, 0);
7426 auto elementsOld = fromElementsOp.getElements();
7427 SmallVector<Value> elementsNew;
7428 int64_t dstNumElements = dstTy.getNumElements();
7429 elementsNew.reserve(dstNumElements);
7433 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7437 for (int64_t j = 0; j < rank; ++j)
7438 srcIdx[j] = dstIdx[inversePerm[j]];
7440 int64_t srcLin =
linearize(srcIdx, srcStrides);
7442 elementsNew.push_back(elementsOld[srcLin]);
7476class FoldTransposeBroadcast :
public OpRewritePattern<vector::TransposeOp> {
7479 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7480 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7482 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7483 PatternRewriter &rewriter)
const override {
7489 "not preceded by a broadcast");
7492 auto inputType = dyn_cast<VectorType>(
broadcast.getSourceType());
7493 VectorType outputType = transpose.getResultVectorType();
7496 bool inputIsScalar = !inputType;
7497 if (inputIsScalar) {
7503 ArrayRef<int64_t> permutation = transpose.getPermutation();
7504 ArrayRef<int64_t> inputShape = inputType.getShape();
7505 int64_t inputRank = inputType.getRank();
7506 int64_t outputRank = transpose.getType().getRank();
7507 int64_t deltaRank = outputRank - inputRank;
7510 for (
int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7511 bool notOne = inputShape[inputIndex] != 1;
7512 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7513 bool groupEndFound = notOne || prevNotOne;
7514 if (groupEndFound) {
7515 int high = inputIndex + deltaRank;
7519 for (
int i = low; i < high; ++i) {
7520 if (permutation[i] < low || permutation[i] >= high) {
7522 transpose,
"permutation not local to group");
7536 vector::BroadcastableToResult::Success &&
7537 "not broadcastable directly to transpose output");
7548void vector::TransposeOp::getCanonicalizationPatterns(
7549 RewritePatternSet &results, MLIRContext *context) {
7550 results.
add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7551 FoldTransposeSplat, FoldTransposeFromElements,
7552 FoldTransposeBroadcast>(context);
7559void ConstantMaskOp::build(OpBuilder &builder, OperationState &
result,
7561 assert(kind == ConstantMaskKind::AllTrue ||
7562 kind == ConstantMaskKind::AllFalse);
7563 build(builder,
result, type,
7564 kind == ConstantMaskKind::AllTrue
7566 : SmallVector<int64_t>(type.getRank(), 0));
7569LogicalResult ConstantMaskOp::verify() {
7570 auto resultType = llvm::cast<VectorType>(getResult().
getType());
7572 if (resultType.getRank() == 0) {
7573 if (getMaskDimSizes().size() != 1)
7574 return emitError(
"array attr must have length 1 for 0-D vectors");
7575 auto dim = getMaskDimSizes()[0];
7576 if (dim != 0 && dim != 1)
7577 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
7582 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
7584 "must specify array attr of size equal vector result rank");
7587 auto resultShape = resultType.getShape();
7588 auto resultScalableDims = resultType.getScalableDims();
7589 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7590 for (
const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7591 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7593 "array attr of size out of bounds of vector result dimension size");
7594 if (resultScalableDims[index] && maskDimSize != 0 &&
7595 maskDimSize != resultShape[index])
7597 "only supports 'none set' or 'all set' scalable dimensions");
7601 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7602 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
7603 if (anyZeros && !allZeros)
7604 return emitOpError(
"expected all mask dim sizes to be zeros, "
7605 "as a result of conjunction with zero mask dim");
7609bool ConstantMaskOp::isAllOnesMask() {
7612 if (resultType.getRank() == 0) {
7613 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
7614 return getMaskDimSizes()[0] == 1;
7616 for (
const auto [resultSize, maskDimSize] :
7617 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7618 if (maskDimSize < resultSize)
7624OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7625 ArrayRef<int64_t> bounds = getMaskDimSizes();
7628 auto createBoolSplat = [&](
bool x) {
7634 if (vectorSizes.empty()) {
7635 assert(bounds.size() == 1 &&
"invalid sizes for zero rank mask");
7636 return createBoolSplat(bounds[0] == 1);
7639 if (bounds == vectorSizes)
7640 return createBoolSplat(
true);
7641 if (llvm::all_of(bounds, [](int64_t x) {
return x == 0; }))
7642 return createBoolSplat(
false);
7643 return OpFoldResult();
7650void CreateMaskOp::build(OpBuilder &builder, OperationState &
result,
7652 ArrayRef<OpFoldResult> mixedOperands) {
7653 SmallVector<Value> operands =
7655 build(builder,
result, type, operands);
7658LogicalResult CreateMaskOp::verify() {
7659 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
7661 if (vectorType.getRank() == 0) {
7662 if (getNumOperands() != 1)
7664 "must specify exactly one operand for 0-D create_mask");
7665 }
else if (getNumOperands() !=
7666 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
7668 "must specify an operand for each result vector dimension");
7698class CreateMaskFolder final :
public OpRewritePattern<CreateMaskOp> {
7702 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7703 PatternRewriter &rewriter)
const override {
7704 VectorType maskType = createMaskOp.getVectorType();
7705 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7706 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7709 constexpr std::array<int64_t, 1> rankZeroShape{1};
7710 constexpr std::array<bool, 1> rankZeroScalableDims{
false};
7711 if (maskType.getRank() == 0) {
7712 maskTypeDimSizes = rankZeroShape;
7713 maskTypeDimScalableFlags = rankZeroScalableDims;
7718 SmallVector<int64_t, 4> constantDims;
7719 for (
auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7724 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7726 constantDims.push_back(*intSize);
7730 if (vscaleMultiplier < maskTypeDimSizes[i])
7732 constantDims.push_back(*vscaleMultiplier);
7739 for (
auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7740 value = std::clamp<int64_t>(value, 0, maskDimSize);
7743 if (llvm::is_contained(constantDims, 0))
7744 constantDims.assign(constantDims.size(), 0);
7755void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7756 MLIRContext *context) {
7757 results.
add<CreateMaskFolder>(context);
7765 OpBuilder &builder, OperationState &
result, Value mask,
7766 Operation *maskableOp,
7767 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7768 assert(maskRegionBuilder &&
7769 "builder callback for 'maskRegion' must be present");
7771 result.addOperands(mask);
7772 OpBuilder::InsertionGuard guard(builder);
7773 Region *maskRegion =
result.addRegion();
7775 maskRegionBuilder(builder, maskableOp);
7780 Value mask, Operation *maskableOp,
7781 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7782 build(builder,
result, resultTypes, mask, Value(), maskableOp,
7788 Value mask, Value passthru, Operation *maskableOp,
7789 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7790 build(builder,
result, mask, maskableOp, maskRegionBuilder);
7792 result.addOperands(passthru);
7793 result.addTypes(resultTypes);
7796ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &
result) {
7798 result.regions.reserve(1);
7799 Region &maskRegion = *
result.addRegion();
7804 OpAsmParser::UnresolvedOperand mask;
7809 OpAsmParser::UnresolvedOperand passthru;
7811 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
7818 MaskOp::ensureTerminator(maskRegion, builder,
result.location);
7829 SmallVector<Type> resultTypes;
7832 result.types.append(resultTypes);
7838 if (parsePassthru.succeeded()) {
7839 if (resultTypes.empty())
7842 "expects a result if passthru operand is provided");
7851void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7852 p <<
" " << getMask();
7854 p <<
", " << getPassthru();
7858 Block *singleBlock = &getMaskRegion().getBlocks().front();
7865 p <<
" : " << getMask().getType();
7866 if (getNumResults() > 0)
7867 p <<
" -> " << getResultTypes();
7870void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
7873 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7874 MaskOp>::ensureTerminator(region, builder, loc);
7880 if (isa<vector::YieldOp>(block.
back()))
7888 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7889 MaskOp>::ensureTerminator(region, builder, loc);
7895 Operation *maskedOp = &block.
front();
7896 opBuilder.setInsertionPointToEnd(&block);
7897 vector::YieldOp::create(opBuilder, loc, maskedOp->
getResults());
7900LogicalResult MaskOp::verify() {
7902 Block &block = getMaskRegion().getBlocks().
front();
7904 return emitOpError(
"expects a terminator within the mask region");
7907 if (numMaskRegionOps > 2)
7908 return emitOpError(
"expects only one operation to mask");
7911 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
7913 return emitOpError(
"expects a terminator within the mask region");
7915 if (terminator->getNumOperands() != getNumResults())
7917 "expects number of results to match mask region yielded values");
7920 if (numMaskRegionOps == 1)
7923 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
7925 return emitOpError(
"expects a MaskableOpInterface within the mask region");
7929 return emitOpError(
"expects number of results to match maskable operation "
7930 "number of results");
7932 if (!llvm::equal(maskableOp->
getResults(), terminator.getOperands()))
7933 return emitOpError(
"expects all the results from the MaskableOpInterface "
7934 "to match all the values returned by the terminator");
7936 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
7938 "expects result type to match maskable operation result type");
7941 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7942 return emitOpError(
"multiple vector results not supported");
7945 Type expectedMaskType = maskableOp.getExpectedMaskType();
7946 if (getMask().
getType() != expectedMaskType)
7948 << expectedMaskType <<
" mask for the maskable operation";
7951 Value passthru = getPassthru();
7953 if (!maskableOp.supportsPassthru())
7955 "doesn't expect a passthru argument for this maskable operation");
7958 return emitOpError(
"expects result when passthru argument is provided");
7961 return emitOpError(
"expects passthru type to match result type");
7981static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7982 SmallVectorImpl<OpFoldResult> &results) {
7983 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7986 Block *block = maskOp.getMaskBlock();
7987 auto terminator = cast<vector::YieldOp>(block->
front());
7988 if (terminator.getNumOperands() == 0)
7992 llvm::append_range(results, terminator.getOperands());
7996LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7997 SmallVectorImpl<OpFoldResult> &results) {
7998 if (succeeded(foldEmptyMaskOp(*
this, adaptor, results)))
8008 Operation *maskableOp = getMaskableOp();
8014 llvm::append_range(results, maskableOp->
getResults());
8030class CanonializeEmptyMaskOp :
public OpRewritePattern<MaskOp> {
8033 LogicalResult matchAndRewrite(MaskOp maskOp,
8034 PatternRewriter &rewriter)
const override {
8035 if (!maskOp.isEmpty())
8038 if (!maskOp.hasPassthru())
8045 VectorType maskType = maskOp.getMask().getType();
8046 for (Type resultType : maskOp.getResultTypes()) {
8047 auto vecResultType = dyn_cast<VectorType>(resultType);
8048 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
8052 Block *block = maskOp.getMaskBlock();
8053 auto terminator = cast<vector::YieldOp>(block->
front());
8054 assert(terminator.getNumOperands() == 1 &&
8055 "expected one result when passthru is provided");
8058 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
8059 terminator.getOperand(0), maskOp.getPassthru());
8065void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
8066 MLIRContext *context) {
8067 results.
add<CanonializeEmptyMaskOp>(context);
8073Operation *MaskOp::getMaskableOp() {
8074 Block *block = getMaskBlock();
8078 return &block->
front();
8082bool MaskOp::hasPassthru() {
return getPassthru() != Value(); }
8088LogicalResult ScanOp::verify() {
8089 VectorType srcType = getSourceType();
8090 VectorType initialType = getInitialValueType();
8092 int64_t srcRank = srcType.getRank();
8093 int64_t reductionDim = getReductionDim();
8094 if (reductionDim >= srcRank)
8096 << reductionDim <<
" has to be less than " << srcRank;
8099 int64_t initialValueRank = initialType.getRank();
8100 if (initialValueRank != srcRank - 1)
8102 << initialValueRank <<
" has to be equal to " << srcRank - 1;
8105 ArrayRef<int64_t> srcShape = srcType.getShape();
8106 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
8107 SmallVector<int64_t> expectedShape;
8108 for (
int i = 0; i < srcRank; i++) {
8109 if (i != reductionDim)
8110 expectedShape.push_back(srcShape[i]);
8112 if (!llvm::equal(initialValueShapes, expectedShape)) {
8113 return emitOpError(
"incompatible input/initial value shapes");
8117 Type eltType = getDestType().getElementType();
8120 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
8127 RewritePatternSet &patterns, PatternBenefit benefit) {
8129 .
add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
8130 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
8131 StridedSliceConstantMaskFolder, TransposeFolder>(
8136 CombiningKind kind, Value v1, Value acc,
8137 arith::FastMathFlagsAttr fastmath,
8144 case CombiningKind::ADD:
8146 result =
b.createOrFold<arith::AddIOp>(loc, v1, acc);
8147 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8148 result =
b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
8150 llvm_unreachable(
"invalid value types for ADD reduction");
8152 case CombiningKind::AND:
8154 result =
b.createOrFold<arith::AndIOp>(loc, v1, acc);
8156 case CombiningKind::MAXNUMF:
8157 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8158 "expected float values");
8159 result =
b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
8161 case CombiningKind::MAXIMUMF:
8162 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8163 "expected float values");
8164 result =
b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
8166 case CombiningKind::MINNUMF:
8167 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8168 "expected float values");
8169 result =
b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
8171 case CombiningKind::MINIMUMF:
8172 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8173 "expected float values");
8174 result =
b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
8176 case CombiningKind::MAXSI:
8178 result =
b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
8180 case CombiningKind::MINSI:
8182 result =
b.createOrFold<arith::MinSIOp>(loc, v1, acc);
8184 case CombiningKind::MAXUI:
8186 result =
b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
8188 case CombiningKind::MINUI:
8190 result =
b.createOrFold<arith::MinUIOp>(loc, v1, acc);
8192 case CombiningKind::MUL:
8194 result =
b.createOrFold<arith::MulIOp>(loc, v1, acc);
8195 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8196 result =
b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
8198 llvm_unreachable(
"invalid value types for MUL reduction");
8200 case CombiningKind::OR:
8202 result =
b.createOrFold<arith::OrIOp>(loc, v1, acc);
8204 case CombiningKind::XOR:
8206 result =
b.createOrFold<arith::XOrIOp>(loc, v1, acc);
8210 assert(
result &&
"unknown CombiningKind");
8218void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
8220 auto resultType = cast<VectorType>(
getType());
8221 if (resultType.isScalable()) {
8225 APInt zero(bitwidth, 0);
8226 APInt high(bitwidth, resultType.getDimSize(0) - 1);
8227 ConstantIntRanges
result = {zero, high, zero, high};
8228 setResultRanges(getResult(),
result);
8258struct StepCompareFolder :
public OpRewritePattern<StepOp> {
8261 LogicalResult matchAndRewrite(StepOp stepOp,
8262 PatternRewriter &rewriter)
const override {
8263 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
8265 for (OpOperand &use : stepOp.getResult().getUses()) {
8266 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8271 const unsigned stepOperandNumber = use.getOperandNumber();
8272 if (stepOperandNumber != 0)
8276 unsigned constOperandNumber = 1;
8277 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8278 std::optional<int64_t> maybeConstValue =
8280 if (!maybeConstValue.has_value())
8283 int64_t constValue = maybeConstValue.value();
8284 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8286 auto maybeSplat = [&]() -> std::optional<bool> {
8288 if ((pred == arith::CmpIPredicate::ult ||
8289 pred == arith::CmpIPredicate::uge) &&
8290 stepSize <= constValue)
8291 return pred == arith::CmpIPredicate::ult;
8294 if ((pred == arith::CmpIPredicate::ule ||
8295 pred == arith::CmpIPredicate::ugt) &&
8296 stepSize - 1 <= constValue) {
8297 return pred == arith::CmpIPredicate::ule;
8301 if ((pred == arith::CmpIPredicate::eq ||
8302 pred == arith::CmpIPredicate::ne) &&
8303 stepSize <= constValue)
8304 return pred == arith::CmpIPredicate::ne;
8306 return std::nullopt;
8309 if (!maybeSplat.has_value())
8314 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8319 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8331void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8332 MLIRContext *context) {
8333 results.
add<StepCompareFolder>(context);
8343 Operation *maskableOp) {
8344 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
8356 Operation *maskableOp, Value mask,
8361 return MaskOp::create(builder, maskableOp->
getLoc(),
8364 return MaskOp::create(builder, maskableOp->
getLoc(),
8377 Value newValue, Value passthru) {
8381 return arith::SelectOp::create(builder, newValue.
getLoc(), newValue.
getType(),
8382 mask, newValue, passthru);
8393struct InterleaveDeinterleaveFolder :
public OpRewritePattern<InterleaveOp> {
8396 LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
8397 PatternRewriter &rewriter)
const override {
8398 auto lhsDefOp = interleaveOp.getLhs().getDefiningOp<DeinterleaveOp>();
8399 auto rhsDefOp = interleaveOp.getRhs().getDefiningOp<DeinterleaveOp>();
8400 if (!lhsDefOp || !rhsDefOp || lhsDefOp != rhsDefOp)
8402 for (
auto [idx, operand] : llvm::enumerate(interleaveOp.getOperands())) {
8403 if (cast<OpResult>(operand).getResultNumber() != idx)
8406 rewriter.
replaceOp(interleaveOp, lhsDefOp.getSource());
8412void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
8413 MLIRContext *context) {
8414 results.
add<InterleaveDeinterleaveFolder>(context);
8417std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
8418 return llvm::to_vector<4>(getResultVectorType().
getShape());
8425std::optional<SmallVector<int64_t, 4>> DeinterleaveOp::getShapeForUnroll() {
8426 return llvm::to_vector<4>(getResultVectorType().
getShape());
8433#define GET_ATTRDEF_CLASSES
8434#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8436#define GET_OP_CLASSES
8437#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static OpFoldResult foldShuffleIdentityMask(ShuffleOp op)
Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static OpFoldResult foldShufflePoisonInputs(MLIRContext *context, Attribute v1Attr, Attribute v2Attr)
Fold shuffle poison, poison -> poison.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr, Attribute v2Attr)
Fold a shuffle of constant 1-D inputs by evaluating the mask.
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op)
If a shuffle operand is poison, replace all mask indices that reference it with kPoisonIndex.
static LogicalResult foldSize1TransferPermutationMap(TransferOp op)
When the vector type is vector<1xT>, the permutation map is irrelevant: the single vector lane always...
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
bool hasNegativeStaticStride(MemRefType memRefTy)
Returns true if any stride of memRefTy is statically known to be negative.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)