43#include "llvm/ADT/ArrayRef.h"
44#include "llvm/ADT/Repeated.h"
45#include "llvm/ADT/STLExtras.h"
46#include "llvm/ADT/SmallVector.h"
47#include "llvm/ADT/SmallVectorExtras.h"
48#include "llvm/ADT/StringSet.h"
49#include "llvm/ADT/TypeSwitch.h"
50#include "llvm/Support/Casting.h"
56#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
58#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
79 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
81 for (
bool b : denseElts.getValues<
bool>())
84 else if (!
b && val <= 0)
98 auto shape = m.getType().getShape();
100 bool allFalse =
true;
101 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks,
shape)) {
102 if (maskIdx < dimSize)
115 auto maskOperands = m.getOperands();
116 for (
Value operand : maskOperands) {
117 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
119 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
132 vector::YieldOp::create(builder, loc);
138 switch (combiningKind) {
139 case CombiningKind::ADD:
140 case CombiningKind::MUL:
142 case CombiningKind::MINUI:
143 case CombiningKind::MINSI:
144 case CombiningKind::MAXUI:
145 case CombiningKind::MAXSI:
146 case CombiningKind::AND:
147 case CombiningKind::OR:
148 case CombiningKind::XOR:
150 case CombiningKind::MINNUMF:
151 case CombiningKind::MAXNUMF:
152 case CombiningKind::MINIMUMF:
153 case CombiningKind::MAXIMUMF:
154 return llvm::isa<FloatType>(elementType);
184 VectorType vectorType) {
185 unsigned elementVectorRank = 0;
186 VectorType elementVectorType =
187 llvm::dyn_cast<VectorType>(shapedType.getElementType());
188 if (elementVectorType)
189 elementVectorRank += elementVectorType.getRank();
190 return vectorType.getRank() - elementVectorRank;
194 VectorType vectorType) {
197 if (shapedType.getRank() == 0 &&
203 shapedType.getRank(),
205 shapedType.getContext());
212 vector::TransferReadOp read) {
213 auto readMask = read.getMask();
214 auto writeMask = write.getMask();
220 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
221 if (!couldBeSameSplat)
238 vector::TransferReadOp read) {
239 return !defWrite.hasOutOfBoundsDim() &&
240 defWrite.getIndices() == read.getIndices() &&
241 defWrite.getVectorType() == read.getVectorType() &&
242 defWrite.getPermutationMap() == read.getPermutationMap() &&
243 ((!defWrite.getMask() && !read.getMask()) ||
248 vector::TransferWriteOp priorWrite) {
249 return priorWrite.getIndices() == write.getIndices() &&
250 priorWrite.getMask() == write.getMask() &&
251 priorWrite.getVectorType() == write.getVectorType() &&
252 priorWrite.getPermutationMap() == write.getPermutationMap();
256 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
257 bool testDynamicValueUsingBounds) {
259 if (transferA.getVectorType() != transferB.getVectorType())
261 unsigned rankOffset = transferA.getLeadingShapedRank();
262 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
263 Value indexA = transferA.getIndices()[i];
264 Value indexB = transferB.getIndices()[i];
268 if (i < rankOffset) {
271 if (cstIndexA.has_value() && cstIndexB.has_value()) {
272 if (*cstIndexA != *cstIndexB)
276 if (testDynamicValueUsingBounds) {
279 FailureOr<uint64_t> delta =
281 if (succeeded(delta) && *delta != 0)
284 FailureOr<bool> testEqual =
286 if (succeeded(testEqual) && !testEqual.value())
292 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
293 if (cstIndexA.has_value() && cstIndexB.has_value()) {
294 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
295 if (distance >= vectorDim)
299 if (testDynamicValueUsingBounds) {
302 FailureOr<int64_t> delta =
304 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
307 FailureOr<int64_t> computeDelta =
309 if (succeeded(computeDelta)) {
310 if (std::abs(computeDelta.value()) >= vectorDim)
320 VectorTransferOpInterface transferB,
321 bool testDynamicValueUsingBounds) {
322 if (transferA.getBase() != transferB.getBase())
325 testDynamicValueUsingBounds);
335 for (
auto [posInDim, dimSize, offsetInDim] :
336 llvm::reverse(llvm::zip_equal(position,
shape, offsets))) {
338 if (posInDim < dimSize + offsetInDim)
342 posInDim = offsetInDim;
352 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
354 assert(constOp &&
"Unexpected non-constant index");
355 return constOp.value();
365 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
366 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
367 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
377 llvm::transform(foldResults, std::back_inserter(values),
379 if (
auto attr = dyn_cast<Attribute>(foldResult))
381 builder, loc, cast<IntegerAttr>(attr).getInt())
384 return cast<Value>(foldResult);
397 if (
lhs.getDefiningOp<vector::VectorScaleOp>())
399 if (
rhs.getDefiningOp<vector::VectorScaleOp>())
409 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
410 if (
auto intType = dyn_cast<IntegerType>(expectedType)) {
411 if (intAttr.getType() != expectedType)
412 return IntegerAttr::get(expectedType, intAttr.getInt());
418 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
419 auto intType = dyn_cast<IntegerType>(expectedType);
423 APFloat floatVal = floatAttr.getValue();
424 APInt intVal = floatVal.bitcastToAPInt();
425 return IntegerAttr::get(expectedType, intVal);
464struct VectorInlinerInterface :
public DialectInlinerInterface {
465 using DialectInlinerInterface::DialectInlinerInterface;
474void VectorDialect::initialize() {
476#define GET_ATTRDEF_LIST
477#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
482#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
485 addInterfaces<VectorInlinerInterface>();
487 declarePromisedInterfaces<memref::IndexedAccessOpInterface, LoadOp, StoreOp,
488 MaskedLoadOp, MaskedStoreOp, ExpandLoadOp,
490 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
491 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
493 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
495 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
496 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
497 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
508 return arith::ConstantOp::materialize(builder, value, type, loc);
524void vector::MultiDimReductionOp::build(
OpBuilder &builder,
527 CombiningKind kind) {
529 for (
const auto &en : llvm::enumerate(reductionMask))
531 reductionDims.push_back(en.index());
532 build(builder,
result, kind, source,
acc, reductionDims);
535OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
537 if (getReductionDims().empty())
542std::optional<SmallVector<int64_t, 4>>
543MultiDimReductionOp::getShapeForUnroll() {
544 return llvm::to_vector<4>(getSourceVectorType().
getShape());
547LogicalResult MultiDimReductionOp::verify() {
550 Type inferredReturnType;
551 auto sourceScalableDims = getSourceVectorType().getScalableDims();
552 for (
auto [dimIdx, dimSize] :
553 llvm::enumerate(getSourceVectorType().
getShape()))
554 if (!llvm::any_of(getReductionDims(),
555 [dimIdx = dimIdx](
int64_t reductionDimIdx) {
556 return reductionDimIdx ==
static_cast<int64_t>(dimIdx);
558 targetShape.push_back(dimSize);
559 scalableDims.push_back(sourceScalableDims[dimIdx]);
562 if (targetShape.empty())
563 inferredReturnType = getSourceVectorType().getElementType();
565 inferredReturnType = VectorType::get(
566 targetShape, getSourceVectorType().
getElementType(), scalableDims);
567 if (
getType() != inferredReturnType)
569 <<
" is incompatible with source type "
570 << getSourceVectorType();
576Type MultiDimReductionOp::getExpectedMaskType() {
577 auto vecType = getSourceVectorType();
578 return VectorType::get(vecType.getShape(),
579 IntegerType::get(vecType.getContext(), 1),
580 vecType.getScalableDims());
589struct ElideUnitDimsInMultiDimReduction
593 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
594 PatternRewriter &rewriter)
const override {
595 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
596 for (
const auto &dim :
enumerate(shape)) {
597 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
602 OpBuilder::InsertionGuard guard(rewriter);
605 if (reductionOp.isMasked()) {
607 rootOp = reductionOp.getMaskingOp();
608 mask = reductionOp.getMaskingOp().getMask();
610 rootOp = reductionOp;
613 Location loc = reductionOp.getLoc();
614 Value acc = reductionOp.getAcc();
616 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
618 VectorType newMaskType =
619 VectorType::get(dstVecType.getShape(), rewriter.
getI1Type(),
620 dstVecType.getScalableDims());
621 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
623 cast = vector::ShapeCastOp::create(
624 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
629 mask = vector::ExtractOp::create(rewriter, loc, mask);
630 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
635 cast,
nullptr, mask);
642void MultiDimReductionOp::getCanonicalizationPatterns(
644 results.
add<ElideUnitDimsInMultiDimReduction>(context);
653 arith::FastMathFlags fastMathFlags) {
659 arith::FastMathFlags fastMathFlags) {
661 llvm::cast<VectorType>(
vector.getType()).getElementType(), kind,
vector,
665LogicalResult ReductionOp::verify() {
667 int64_t rank = getSourceVectorType().getRank();
669 return emitOpError(
"unsupported reduction rank: ") << rank;
672 Type eltType = getDest().getType();
675 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
684Type ReductionOp::getExpectedMaskType() {
685 auto vecType = getSourceVectorType();
686 return VectorType::get(vecType.getShape(),
687 IntegerType::get(vecType.getContext(), 1),
688 vecType.getScalableDims());
695 case arith::AtomicRMWKind::addf:
696 case arith::AtomicRMWKind::addi:
697 return vector::ReductionOp::create(builder,
vector.getLoc(),
698 CombiningKind::ADD,
vector);
699 case arith::AtomicRMWKind::mulf:
700 case arith::AtomicRMWKind::muli:
701 return vector::ReductionOp::create(builder,
vector.getLoc(),
702 CombiningKind::MUL,
vector);
703 case arith::AtomicRMWKind::minimumf:
704 return vector::ReductionOp::create(builder,
vector.getLoc(),
705 CombiningKind::MINIMUMF,
vector);
706 case arith::AtomicRMWKind::mins:
707 return vector::ReductionOp::create(builder,
vector.getLoc(),
708 CombiningKind::MINSI,
vector);
709 case arith::AtomicRMWKind::minu:
710 return vector::ReductionOp::create(builder,
vector.getLoc(),
711 CombiningKind::MINUI,
vector);
712 case arith::AtomicRMWKind::maximumf:
713 return vector::ReductionOp::create(builder,
vector.getLoc(),
714 CombiningKind::MAXIMUMF,
vector);
715 case arith::AtomicRMWKind::maxs:
716 return vector::ReductionOp::create(builder,
vector.getLoc(),
717 CombiningKind::MAXSI,
vector);
718 case arith::AtomicRMWKind::maxu:
719 return vector::ReductionOp::create(builder,
vector.getLoc(),
720 CombiningKind::MAXUI,
vector);
721 case arith::AtomicRMWKind::andi:
722 return vector::ReductionOp::create(builder,
vector.getLoc(),
723 CombiningKind::AND,
vector);
724 case arith::AtomicRMWKind::ori:
725 return vector::ReductionOp::create(builder,
vector.getLoc(),
726 CombiningKind::OR,
vector);
727 case arith::AtomicRMWKind::minnumf:
728 return vector::ReductionOp::create(builder,
vector.getLoc(),
729 CombiningKind::MINNUMF,
vector);
730 case arith::AtomicRMWKind::maxnumf:
731 return vector::ReductionOp::create(builder,
vector.getLoc(),
732 CombiningKind::MAXNUMF,
vector);
733 case arith::AtomicRMWKind::xori:
734 return vector::ReductionOp::create(builder,
vector.getLoc(),
735 CombiningKind::XOR,
vector);
743std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
744 return llvm::to_vector<4>(getSourceVectorType().
getShape());
751 LogicalResult matchAndRewrite(ReductionOp reductionOp,
756 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
759 if (maskableOp.isMasked()) {
761 rootOp = maskableOp.getMaskingOp();
762 mask = maskableOp.getMaskingOp().getMask();
764 rootOp = reductionOp;
767 auto vectorType = reductionOp.getSourceVectorType();
768 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
771 Location loc = reductionOp.getLoc();
773 mask = ExtractOp::create(rewriter, loc, mask);
774 Value
result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
776 if (Value acc = reductionOp.getAcc())
779 reductionOp.getFastmathAttr(), mask);
789 results.
add<ElideSingleElementReduction>(context);
803 getIndexingMapsAttrName(
result.name),
807 getIteratorTypesAttrName(
result.name),
810 return IteratorTypeAttr::get(builder.getContext(), t);
819 ContractionOp::getDefaultKind());
825 ArrayAttr iteratorTypes, CombiningKind kind,
826 arith::FastMathFlags fastMathFlags) {
829 result.addAttribute(getIndexingMapsAttrName(
result.name), indexingMaps);
830 result.addAttribute(getIteratorTypesAttrName(
result.name), iteratorTypes);
832 CombiningKindAttr::get(builder.
getContext(), kind));
833 if (fastMathFlags != arith::FastMathFlags::none)
835 getFastmathAttrName(
result.name),
836 arith::FastMathFlagsAttr::get(builder.
getContext(), fastMathFlags));
847 DictionaryAttr dictAttr;
861 result.attributes.append(dictAttr.getValue().begin(),
862 dictAttr.getValue().end());
868 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
869 result.attributes.get(getIteratorTypesAttrName(
result.name)));
870 if (!iteratorTypes) {
872 <<
"expected " << getIteratorTypesAttrName(
result.name)
873 <<
" array attribute";
878 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
879 auto maybeIteratorType = symbolizeIteratorType(s);
880 if (!maybeIteratorType.has_value())
881 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
883 iteratorTypeAttrs.push_back(
884 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
886 result.attributes.set(getIteratorTypesAttrName(
result.name),
889 if (!
result.attributes.get(getKindAttrName(
result.name))) {
891 getKindAttrName(
result.name),
892 CombiningKindAttr::get(
result.getContext(),
893 ContractionOp::getDefaultKind()));
895 if (masksInfo.empty())
897 if (masksInfo.size() != 2)
899 "expected zero or exactly 2 vector mask operands");
900 auto lhsType = llvm::cast<VectorType>(types[0]);
901 auto rhsType = llvm::cast<VectorType>(types[1]);
903 std::array<VectorType, 2> maskTypes = {
913 auto attrNames = getTraitAttrNames();
915 traitAttrsSet.insert_range(attrNames);
917 for (
auto attr : (*this)->getAttrs()) {
918 if (attr.getName() == getIteratorTypesAttrName()) {
920 llvm::cast<ArrayAttr>(attr.getValue())
921 .getAsValueRange<IteratorTypeAttr, IteratorType>();
927 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) ->
Attribute {
928 return StringAttr::get(
getContext(), stringifyIteratorType(t));
931 attrs.emplace_back(getIteratorTypesAttrName(),
932 ArrayAttr::get(
getContext(), iteratorTypeNames));
933 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
935 if (attr.getName() == getFastmathAttrName() &&
936 llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
937 arith::FastMathFlags::none)
939 attrs.push_back(attr);
943 auto dictAttr = DictionaryAttr::get(
getContext(), attrs);
944 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
945 p << getRhs() <<
", " << getAcc();
948 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
953 const std::vector<std::pair<int64_t, int64_t>> &map) {
954 for (
auto &dimPair : map) {
955 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
956 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
957 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
964 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
966 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
967 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
970 for (
auto &dimPair : contractingDimMap) {
971 lhsContractingDimSet.insert(dimPair.first);
972 rhsContractingDimSet.insert(dimPair.second);
975 llvm::make_second_range(batchDimMap));
979 for (
int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
980 if (lhsContractingDimSet.count(i) > 0)
982 expectedResultDims.push_back(lhsType.getDimSize(i));
986 for (
int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
987 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
989 expectedResultDims.push_back(rhsType.getDimSize(i));
993 if (expectedResultDims.empty()) {
995 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
996 return op.emitOpError(
"invalid accumulator/result vector shape");
999 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
1000 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
1001 if (!resVectorType || !accVectorType)
1002 return op.emitOpError(
"invalid accumulator/result vector shape");
1008 AffineMap lhsMap = op.getIndexingMapsArray()[0];
1009 AffineMap rhsMap = op.getIndexingMapsArray()[1];
1011 return op.emitOpError(
1012 "expected all dimensions to be either a LHS or a RHS dimension");
1015 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1016 VectorType v = pair.first;
1017 auto map = pair.second;
1018 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1019 unsigned pos = map.getDimPosition(idx);
1024 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
1025 return op.emitOpError(
"expected all dimensions to get an extent as "
1026 "either a LHS or a RHS dimension");
1028 AffineMap resMap = op.getIndexingMapsArray()[2];
1033 assert(llvm::all_of(expectedMap.
getResults(),
1034 llvm::IsaPred<AffineConstantExpr>) &&
1035 "expected constant extent along all dimensions.");
1037 auto expectedShape =
1039 return cast<AffineConstantExpr>(e).getValue();
1042 VectorType::get(expectedShape, resVectorType.getElementType(),
1043 resVectorType.getScalableDims());
1044 if (resVectorType != expected || accVectorType != expected)
1045 return op.emitOpError(
1046 "invalid accumulator/result vector shape, expected: ")
1052LogicalResult ContractionOp::verify() {
1053 VectorType lhsType = getLhsType();
1054 VectorType rhsType = getRhsType();
1055 Type accType = getAccType();
1056 Type resType = getResultType();
1058 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1059 if (!lhsType.getElementType().isSignlessInteger())
1060 return emitOpError(
"only supports signless integer types");
1064 if (getIndexingMapsArray().size() != 3)
1065 return emitOpError(
"expected an indexing map for each vector operand");
1070 unsigned numIterators = getIteratorTypes().getValue().size();
1071 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1072 auto index = it.index();
1073 auto map = it.value();
1074 if (map.getNumSymbols() != 0)
1076 <<
index <<
" to have no symbols";
1077 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(
index).
getType());
1078 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1081 if (map.getNumDims() != numIterators)
1083 <<
index <<
" to have " << numIterators <<
" number of inputs";
1084 if (map.getNumResults() != rank)
1086 <<
index <<
" to have " << rank <<
" number of outputs";
1087 if (!map.isProjectedPermutation())
1089 <<
index <<
" to be a projected permutation of its inputs";
1092 auto contractingDimMap = getContractingDimMap();
1093 auto batchDimMap = getBatchDimMap();
1096 if (contractingDimMap.empty())
1097 return emitOpError(
"expected at least one contracting dimension pair");
1100 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1101 return emitOpError(
"invalid contracting dimension map");
1105 return emitOpError(
"invalid batch dimension map");
1109 contractingDimMap, batchDimMap)))
1112 if (!getKindAttr()) {
1113 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
1114 "'vector.kind<add>')");
1118 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1119 auto elementType = vectorType ? vectorType.getElementType() : resType;
1121 return emitOpError(
"unsupported contraction type");
1124 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1131Type ContractionOp::getExpectedMaskType() {
1132 auto indexingMaps = this->getIndexingMapsArray();
1135 VectorType lhsType = this->getLhsType();
1136 VectorType rhsType = this->getRhsType();
1138 unsigned numVecDims = lhsIdxMap.
getNumDims();
1144 for (
auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1147 lhsType.getScalableDims()[dimIdx];
1149 for (
auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1152 rhsType.getScalableDims()[dimIdx];
1155 assert(ShapedType::isStaticShape(maskShape) &&
1156 "Mask shape couldn't be computed");
1158 return VectorType::get(maskShape,
1159 IntegerType::get(lhsType.getContext(), 1),
1160 maskShapeScalableDims);
1165 getIteratorTypesAttrName(), getKindAttrName(),
1166 getFastmathAttrName()};
1176static std::vector<std::pair<int64_t, int64_t>>
1178 IteratorType targetIteratorType,
MLIRContext *context) {
1179 std::vector<std::pair<int64_t, int64_t>> dimMap;
1180 for (
const auto &it : llvm::enumerate(iteratorTypes)) {
1181 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1182 if (iteratorType != targetIteratorType)
1188 if (lhsDim >= 0 && rhsDim >= 0)
1189 dimMap.emplace_back(lhsDim, rhsDim);
1194void ContractionOp::getIterationBounds(
1196 auto lhsShape = getLhsType().getShape();
1197 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1199 for (
const auto &it : llvm::enumerate(getIteratorTypes())) {
1202 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1203 if (iteratorType == IteratorType::reduction) {
1206 assert(lhsDimIndex >= 0);
1207 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1212 assert(resDimIndex >= 0);
1213 assert(resVectorType !=
nullptr);
1214 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1218void ContractionOp::getIterationIndexMap(
1220 unsigned numMaps = getIndexingMapsArray().size();
1221 iterationIndexMap.resize(numMaps);
1222 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1223 auto index = it.index();
1224 auto map = it.value();
1225 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1226 auto dim = cast<AffineDimExpr>(map.getResult(i));
1227 iterationIndexMap[
index][dim.getPosition()] = i;
1232std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1234 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1238std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1240 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1244std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1246 getIterationBounds(
shape);
1268template <
typename AddOpType>
1274 auto canonicalize = [&](
Value maybeContraction,
1275 Value otherOperand) -> vector::ContractionOp {
1276 vector::ContractionOp contractionOp =
1277 dyn_cast_or_null<vector::ContractionOp>(
1280 return vector::ContractionOp();
1281 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1282 contractionOp.getAcc().getDefiningOp())) {
1283 if (maybeZero.getValue() ==
1284 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1286 bvm.
map(contractionOp.getAcc(), otherOperand);
1287 auto newContraction =
1288 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1289 rewriter.
replaceOp(addOp, newContraction.getResult());
1290 return newContraction;
1293 return vector::ContractionOp();
1296 Value a = addOp->getOperand(0),
b = addOp->getOperand(1);
1297 vector::ContractionOp
contract = canonicalize(a,
b);
1322 setResultRanges(getResult(), argRanges.front());
1327 auto vectorTy = cast<VectorType>(source.
getType());
1352 build(builder,
result, source, dynamicPos,
1357ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1358 ExtractOp::Adaptor adaptor,
1360 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1361 if (
static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1362 vectorType.getRank()) {
1363 inferredReturnTypes.push_back(vectorType.getElementType());
1365 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1366 vectorType.getRank());
1367 inferredReturnTypes.push_back(VectorType::get(
1368 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1369 vectorType.getScalableDims().drop_front(n)));
1374LogicalResult vector::ExtractOp::verify() {
1375 if (
auto resTy = dyn_cast<VectorType>(getResult().
getType()))
1376 if (resTy.getRank() == 0)
1378 "expected a scalar instead of a 0-d vector as the result type");
1381 auto dynamicMarkersCount =
1382 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1383 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1385 "mismatch between dynamic and static positions (kDynamic marker but no "
1386 "corresponding dynamic position) -- this can only happen due to an "
1387 "incorrect fold/rewrite");
1388 auto position = getMixedPosition();
1389 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1391 "expected position attribute of rank no greater than vector rank");
1392 for (
auto [idx, pos] : llvm::enumerate(position)) {
1393 if (
auto attr = dyn_cast<Attribute>(pos)) {
1394 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1396 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1397 return emitOpError(
"expected position attribute #")
1399 <<
" to be a non-negative integer smaller than the "
1400 "corresponding vector dimension or poison (-1)";
1407template <
typename IntType>
1409 return llvm::map_to_vector<4>(
1410 arrayAttr.getAsRange<IntegerAttr>(),
1411 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1417 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1421 if (extractOp.hasDynamicPosition())
1425 ExtractOp currentOp = extractOp;
1427 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1428 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1431 if (currentOp.hasDynamicPosition())
1434 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1436 extractOp.setOperand(0, currentOp.getSource());
1439 std::reverse(globalPosition.begin(), globalPosition.end());
1440 extractOp.setStaticPosition(globalPosition);
1452class ExtractFromInsertTransposeChainState {
1454 ExtractFromInsertTransposeChainState(ExtractOp e);
1463 template <
typename ContainerA,
typename ContainerB>
1464 bool isContainedWithin(
const ContainerA &a,
const ContainerB &
b) {
1465 return a.size() <=
b.size() &&
1466 std::equal(a.begin(), a.begin() + a.size(),
b.begin());
1473 template <
typename ContainerA,
typename ContainerB>
1474 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &
b) {
1475 for (
auto [elemA, elemB] : llvm::zip(a,
b)) {
1476 if (elemA < 0 || elemB < 0)
1487 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1491 void updateStateForNextIteration(Value v) {
1498 LogicalResult handleTransposeOp();
1501 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1516 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1521 Value tryToFoldExtractOpInPlace(Value source);
1523 ExtractOp extractOp;
1525 int64_t extractedRank;
1527 InsertOp nextInsertOp;
1528 TransposeOp nextTransposeOp;
1538 SmallVector<int64_t> sentinels;
1539 SmallVector<int64_t> extractPosition;
1543ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1545 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1546 extractedRank(extractOp.getNumIndices()) {
1547 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1548 sentinels.reserve(vectorRank - extractedRank);
1549 for (
int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1550 sentinels.push_back(-(i + 1));
1552 extractOp.getStaticPosition().end());
1558LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1560 if (extractOp.hasDynamicPosition())
1563 if (!nextTransposeOp)
1566 nextTransposeOp.getPermutation(), extractOp.getContext()));
1573ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1576 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1579 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1580 if (insertedPos != llvm::ArrayRef(
extractPosition).take_front(extractedRank))
1583 res = nextInsertOp.getValueToStore();
1592ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1594 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1597 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1607 res = nextInsertOp.getValueToStore();
1615Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1618 if (extractOp.hasDynamicPosition())
1622 bool nothingToFold = (source == extractOp.getSource());
1623 if (nothingToFold || !canFold())
1627 OpBuilder
b(extractOp.getContext());
1628 extractOp.setStaticPosition(
1630 extractOp.getSourceMutable().assign(source);
1631 return extractOp.getResult();
1635Value ExtractFromInsertTransposeChainState::fold() {
1637 if (extractOp.hasDynamicPosition())
1640 Value valueToExtractFrom = extractOp.getSource();
1641 updateStateForNextIteration(valueToExtractFrom);
1642 while (nextInsertOp || nextTransposeOp) {
1645 if (succeeded(handleTransposeOp())) {
1646 valueToExtractFrom = nextTransposeOp.getVector();
1647 updateStateForNextIteration(valueToExtractFrom);
1653 if (succeeded(handleInsertOpWithMatchingPos(
result)))
1658 if (succeeded(handleInsertOpWithPrefixPos(
result)))
1659 return tryToFoldExtractOpInPlace(
result);
1663 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1669 valueToExtractFrom = nextInsertOp.getDest();
1670 updateStateForNextIteration(valueToExtractFrom);
1673 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1678 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1679 auto vecType = dyn_cast<VectorType>(type);
1680 return vecType && vecType.getRank() == 0;
1690 if (isa<BroadcastOp>(op))
1693 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1701 VectorType srcType = shapeCast.getSourceVectorType();
1703 uint64_t srcRank = srcType.getRank();
1705 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1731 Operation *defOp = extractOp.getSource().getDefiningOp();
1738 if (extractOp.getType() == input.
getType())
1744 auto inputType = llvm::dyn_cast<VectorType>(input.
getType());
1745 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1746 unsigned inputRank = inputType ? inputType.getRank() : 0;
1747 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1748 unsigned extractRank = extractType ? extractType.getRank() : 0;
1751 if (extractRank > inputRank)
1755 assert(inputType &&
"input must be a vector type because of previous checks");
1764 extractType.getShape() != inputShape.take_back(extractRank))
1769 unsigned deltaOverall = inputRank - extractRank;
1770 unsigned deltaBroadcast = broadcastRank - inputRank;
1774 for (
auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1775 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1778 extractOp->setOperands(
1779 llvm::to_vector(llvm::concat<Value>(
ValueRange(input), dynPos)));
1780 extractOp.setStaticPosition(staticPos);
1781 return extractOp.getResult();
1797 if (extractOp.hasDynamicPosition())
1800 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1805 if (shuffleOp.getResultVectorType().getRank() != 1)
1808 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1809 auto shuffleMask = shuffleOp.getMask();
1810 int64_t extractIdx = extractOp.getStaticPosition()[0];
1811 int64_t shuffleIdx = shuffleMask[extractIdx];
1814 if (shuffleIdx < inputVecSize) {
1815 extractOp.setOperand(0, shuffleOp.getV1());
1816 extractOp.setStaticPosition({shuffleIdx});
1818 extractOp.setOperand(0, shuffleOp.getV2());
1819 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1822 return extractOp.getResult();
1828 if (extractOp.hasDynamicPosition())
1831 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1836 auto getDimReverse = [](VectorType type,
int64_t n) {
1837 return type.getShape().take_back(n + 1).front();
1840 llvm::isa<VectorType>(extractOp.getType())
1841 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1843 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1845 if (destinationRank > 0) {
1846 auto destinationType =
1847 llvm::cast<VectorType>(extractOp.getResult().getType());
1848 for (
int64_t i = 0; i < destinationRank; i++) {
1852 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1853 getDimReverse(destinationType, i))
1860 std::reverse(extractedPos.begin(), extractedPos.end());
1863 for (
int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1864 strides.push_back(stride);
1866 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1874 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1876 for (
int64_t i = 0; i < numDimension; i++) {
1877 newStrides.push_back(stride);
1879 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1881 std::reverse(newStrides.begin(), newStrides.end());
1885 extractOp.setStaticPosition(newPosition);
1886 extractOp.setOperand(0, shapeCastOp.getSource());
1887 return extractOp.getResult();
1893 if (extractOp.hasDynamicPosition())
1896 auto extractStridedSliceOp =
1897 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1898 if (!extractStridedSliceOp)
1907 if (extractStridedSliceOp.hasNonUnitStrides())
1913 while (!sliceOffsets.empty()) {
1914 size_t lastOffset = sliceOffsets.size() - 1;
1915 if (sliceOffsets.back() != 0 ||
1916 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1917 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1919 sliceOffsets.pop_back();
1921 unsigned destinationRank = 0;
1922 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1923 destinationRank = vecType.getRank();
1926 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1927 sliceOffsets.size())
1931 assert(extractedPos.size() >= sliceOffsets.size());
1932 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1933 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1934 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1938 extractOp.setStaticPosition(extractedPos);
1939 return extractOp.getResult();
1945 if (extractOp.hasDynamicPosition())
1949 llvm::isa<VectorType>(extractOp.getType())
1950 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1952 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1962 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1963 insertOp.getSourceVectorType().getRank();
1964 if (destinationRank > insertOp.getSourceVectorType().getRank())
1969 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1970 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1973 bool disjoint =
false;
1975 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1976 int64_t start = insertOffsets[dim];
1978 (dim < insertRankDiff)
1980 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1982 int64_t offset = extractOffsets[dim];
1984 if (start <= offset && offset < end) {
1985 if (dim >= insertRankDiff)
1986 offsetDiffs.push_back(offset - start);
1997 insertOp.getSourceVectorType().getRank() - destinationRank;
1998 for (
int64_t i = 0; i < destinationRank; i++) {
1999 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
2000 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
2004 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2007 extractOp.setStaticPosition(offsetDiffs);
2008 return extractOp.getResult();
2012 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2025 if (extractOp.hasDynamicPosition())
2029 auto fromElementsOp = extractOp.getSource().
getDefiningOp<FromElementsOp>();
2030 if (!fromElementsOp)
2034 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2035 if (vecType.isScalable())
2039 int64_t rank = vecType.getRank();
2041 if (extractOp.getType() != vecType.getElementType())
2044 "unexpected number of indices");
2049 for (
int i = rank - 1; i >= 0; --i) {
2050 flatIndex +=
indices[i] * stride;
2051 stride *= vecType.getDimSize(i);
2053 return fromElementsOp.getElements()[flatIndex];
2058template <
typename OpType,
typename AdaptorType>
2061 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2062 OperandRange dynamicPosition = op.getDynamicPosition();
2065 if constexpr (std::is_same_v<OpType, ExtractOp>)
2066 vectorShape = op.getSourceVectorType().getShape();
2071 if (!dynamicPosition.size())
2078 bool opChange =
false;
2079 for (
unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2080 if (ShapedType::isStatic(staticPosition[i]))
2084 if (
auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2085 int64_t value = attr.getInt();
2089 staticPosition[i] = attr.getInt();
2094 operands.push_back(position);
2098 op.setStaticPosition(staticPosition);
2099 op.getOperation()->setOperands(operands);
2101 return op.getResult();
2111 if (!is_contained(staticPos, poisonVal))
2114 return ub::PoisonAttr::get(context);
2128 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2133 if (denseAttr.isSplat()) {
2135 if (
auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2140 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2141 if (vecTy.isScalable())
2144 if (extractOp.hasDynamicPosition()) {
2159 copy(extractOp.getStaticPosition(), completePositions.begin());
2162 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2165 if (
auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2167 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2170 newAttr = *denseValuesBegin;
2176OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2180 if (getNumIndices() == 0 && getSource().
getType() == getResult().
getType())
2187 SmallVector<Value> operands = {getSource()};
2191 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2197 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2212 return inplaceFolded;
2218class ExtractOpFromBroadcast final :
public OpRewritePattern<ExtractOp> {
2222 LogicalResult matchAndRewrite(ExtractOp extractOp,
2223 PatternRewriter &rewriter)
const override {
2226 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2232 BroadcastableToResult::Success)
2241class ExtractOpFromCreateMask final :
public OpRewritePattern<ExtractOp> {
2245 LogicalResult matchAndRewrite(ExtractOp extractOp,
2246 PatternRewriter &rewriter)
const override {
2248 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2252 VectorType extractedMaskType =
2253 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2255 if (!extractedMaskType)
2258 auto maskOperands = createMaskOp.getOperands();
2259 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2260 VectorType maskType = createMaskOp.getVectorType();
2262 bool containsUnknownDims =
false;
2265 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2267 int64_t pos = extractOpPos[dimIdx];
2268 Value operand = maskOperands[dimIdx];
2269 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2272 containsUnknownDims =
true;
2276 int64_t createMaskBound =
2277 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2279 if (pos != ShapedType::kDynamic) {
2282 allFalse |= pos >= createMaskBound;
2283 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2287 containsUnknownDims =
true;
2294 }
else if (!containsUnknownDims) {
2296 extractOp, extractedMaskType,
2297 maskOperands.drop_front(extractOpPos.size()));
2306class ExtractOpFromConstantMask final :
public OpRewritePattern<ExtractOp> {
2310 LogicalResult matchAndRewrite(ExtractOp extractOp,
2311 PatternRewriter &rewriter)
const override {
2312 auto constantMaskOp =
2313 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2314 if (!constantMaskOp)
2317 Type resultType = extractOp.getResult().getType();
2318 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2320 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2321 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2323 VectorType maskType = constantMaskOp.getVectorType();
2326 for (
size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2327 int64_t pos = extractOpPos[dimIdx];
2328 if (pos == ShapedType::kDynamic) {
2331 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2340 if (pos >= maskDimSizes[dimIdx]) {
2341 if (extractedMaskType) {
2353 if (extractedMaskType) {
2357 extractOp, extractedMaskType,
2358 maskDimSizes.drop_front(extractOpPos.size()));
2371LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2372 PatternRewriter &rewriter) {
2373 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2377 VectorType sourceType = castOp.getSourceVectorType();
2378 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2382 if (sourceType.getNumElements() != targetType.getNumElements())
2386 castOp.getSource());
2396LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2397 PatternRewriter &rewriter) {
2399 if (extractOp.hasDynamicPosition())
2403 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2408 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2409 if (!fromElementsOp)
2411 VectorType inputType = fromElementsOp.getType();
2414 if (resultType.isScalable() || inputType.isScalable())
2419 SmallVector<int64_t> firstElementPos =
2420 llvm::to_vector(extractOp.getStaticPosition());
2421 firstElementPos.append(resultType.getRank(), 0);
2424 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2425 flatIndex += firstElementPos[i] * stride;
2426 stride *= inputType.getDimSize(i);
2431 extractOp, resultType,
2432 fromElementsOp.getElements().slice(flatIndex,
2433 resultType.getNumElements()));
2445struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2447 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2448 PatternRewriter &rewriter)
const override {
2449 VectorType sourceType = extractOp.getSourceVectorType();
2450 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2454 if (sourceType.getNumElements() != outType.getNumElements())
2456 extractOp,
"extract to vector with fewer elements");
2460 if (llvm::any_of(extractOp.getMixedPosition(),
2461 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2463 "leaving for extract poison folder");
2466 extractOp.getSource());
2487struct FoldExtractFromInsertUnitDim final
2488 : OpRewritePattern<vector::ExtractOp> {
2491 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2492 PatternRewriter &rewriter)
const override {
2493 if (extractOp.hasDynamicPosition())
2496 auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
2497 if (!insertOp || insertOp.hasDynamicPosition())
2500 ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
2501 ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
2504 if (extractPos.size() >= insertPos.size() ||
2505 extractPos != insertPos.take_front(extractPos.size()))
2511 auto srcVecType = extractOp.getSourceVectorType();
2512 for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
2513 if (srcVecType.getDimSize(i) != 1)
2516 Value
inserted = insertOp.getValueToStore();
2517 Type extractedType = extractOp.getResult().getType();
2518 if (isa<VectorType>(
inserted.getType())) {
2525 extractOp, extractOp.getResult().
getType(),
2526 insertOp.getValueToStore());
2534void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2535 MLIRContext *context) {
2536 results.
add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2537 ExtractOpFromConstantMask, ExtractToShapeCast,
2538 FoldExtractFromInsertUnitDim>(context);
2539 results.
add(foldExtractFromShapeCastToShapeCast);
2540 results.
add(foldExtractFromFromElements);
2545 for (
auto attr : arrayAttr)
2546 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2553std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2564 if (operands.empty())
2567 return llvm::all_of(operands, [&](
Value operand) {
2569 return currentDef == defOp;
2587 auto fromElementsOp =
2588 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2589 if (!fromElementsOp)
2592 llvm::append_range(results, fromElementsOp.getElements());
2609 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2613 if (isa<VectorType>(bcastOp.getSource().getType()))
2616 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2618 Value scalar = bcastOp.getSource();
2619 results.assign(resultVecType.getNumElements(), scalar);
2623LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2624 SmallVectorImpl<OpFoldResult> &results) {
2629 if (
auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2630 setOperand(shapeCast.getSource());
2638ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2639 ToElementsOp::Adaptor adaptor,
2640 SmallVectorImpl<Type> &inferredReturnTypes) {
2641 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2642 Type elType = vecType.getElementType();
2643 inferredReturnTypes.append(vecType.getNumElements(), elType);
2665 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2670 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2674 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2679 int64_t dstRank = dstShape.size();
2680 int64_t srcRank = srcShape.size();
2683 auto srcElems = vector::ToElementsOp::create(
2684 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2686 int64_t dstCount = llvm::product_of(dstShape);
2689 replacements.reserve(dstCount);
2714 for (
int64_t lin = 0; lin < dstCount; ++lin) {
2717 for (
int64_t k = 0; k < srcRank; ++k)
2718 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2721 replacements.push_back(srcElems.getResult(srcLin));
2724 rewriter.
replaceOp(toElementsOp, replacements);
2729void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2730 MLIRContext *context) {
2731 results.
add<ToElementsOfBroadcast>(context);
2751 OperandRange fromElemsOperands = fromElementsOp.getElements();
2752 if (fromElemsOperands.empty())
2755 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2763 Value toElementsInput = toElementsOp.getSource();
2764 if (fromElementsOp.getType() == toElementsInput.
getType() &&
2765 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2766 return toElementsInput;
2786 if (llvm::any_of(elements, [](
Attribute attr) {
2792 auto destVecType = fromElementsOp.getDest().getType();
2793 auto destEltType = destVecType.getElementType();
2794 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2799 auto convertedElements = llvm::map_to_vector(elements, [&](
Attribute attr) {
2806OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2823 if (!llvm::all_equal(fromElementsOp.getElements()))
2826 fromElementsOp, fromElementsOp.getType(),
2827 fromElementsOp.getElements().front());
2855 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2859 if (fromElements.getType().getNumElements() == 1)
2870 for (
auto [insertIndex, element] :
2871 llvm::enumerate(fromElements.getElements())) {
2874 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2877 "element not from vector.extract");
2882 if (insertIndex == 0) {
2883 source = extractOp.getSource();
2884 }
else if (extractOp.getSource() != source) {
2886 "element from different vector");
2890 int64_t rank = position.size();
2891 assert(rank == source.getType().getRank() &&
2892 "scalar extract must have full rank position");
2903 if (insertIndex == 0) {
2904 const int64_t numElms = fromElements.getType().getNumElements();
2907 while (
index > 0 && position[
index - 1] == 0 &&
2908 numSuffixElms < numElms) {
2909 numSuffixElms *= source.getType().getDimSize(
index - 1);
2912 if (numSuffixElms != numElms) {
2914 fromElements,
"elements do not form a suffix of source");
2916 expectedPosition = llvm::to_vector(position);
2917 combinedPosition = position.drop_back(rank -
index);
2921 else if (expectedPosition != position) {
2923 fromElements,
"elements not in ascending order (static order)");
2925 increment(expectedPosition, source.getType().getShape());
2928 auto extracted = rewriter.
createOrFold<vector::ExtractOp>(
2929 fromElements.getLoc(), source, combinedPosition);
2932 fromElements, fromElements.getType(), extracted);
2940 for (
int dim : llvm::reverse(llvm::seq<int>(0,
indices.size()))) {
2959void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2961 setResultRanges(getResult(), argRanges.front());
2964std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2965 return llvm::to_vector<4>(getResultVectorType().
getShape());
2970static llvm::SetVector<int64_t>
2973 int64_t rankDiff = dstShape.size() - srcShape.size();
2976 for (
auto [s1, s2] :
2977 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2979 assert(s1 == 1 &&
"expected \"dim-1\" broadcasting");
2987llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2989 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2992 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
3008Value BroadcastOp::createOrFoldBroadcastOp(
3009 OpBuilder &
b, Value value, ArrayRef<int64_t> dstShape,
3010 const llvm::SetVector<int64_t> &broadcastedDims) {
3011 assert(!dstShape.empty() &&
"unexpected empty dst shape");
3014 SmallVector<int64_t> checkShape;
3015 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
3016 if (broadcastedDims.contains(i))
3018 checkShape.push_back(dstShape[i]);
3020 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
3021 "ill-formed broadcastedDims contains values not confined to "
3024 Location loc = value.
getLoc();
3026 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
3027 VectorType dstVectorType = VectorType::get(dstShape, elementType);
3030 if (!srcVectorType) {
3031 assert(checkShape.empty() &&
3032 "ill-formed createOrFoldBroadcastOp arguments");
3033 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
3036 assert(srcVectorType.getShape().equals(checkShape) &&
3037 "ill-formed createOrFoldBroadcastOp arguments");
3047 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
3048 broadcastShape.reserve(dstShape.size());
3064 int64_t nextSrcShapeDim = broadcastedDims.size();
3065 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
3066 if (broadcastedDims.contains(i)) {
3071 broadcastShape.push_back(dstShape[i]);
3072 permutation[i] = broadcastShape.size() - 1;
3078 permutation[i] = nextSrcShapeDim++;
3082 llvm::append_range(broadcastShape, srcVectorType.getShape());
3087 "unexpected \"dim-1\" broadcast");
3089 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3091 vector::BroadcastableToResult::Success &&
3092 "must be broadcastable");
3093 Value res =
b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3096 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3097 if (permutation[i] != i)
3098 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3104 Type srcType, VectorType dstVectorType,
3105 std::pair<VectorDim, VectorDim> *mismatchingDims) {
3107 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3111 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3115 int64_t srcRank = srcVectorType.getRank();
3116 int64_t dstRank = dstVectorType.getRank();
3117 if (srcRank > dstRank)
3121 int64_t lead = dstRank - srcRank;
3122 for (
int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3125 bool foundMismatchingDims =
false;
3128 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3129 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3130 if (srcDim != 1 && srcDim != dstDim)
3131 foundMismatchingDims =
true;
3134 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3135 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3136 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3139 (srcDimScalableFlag != dstDimScalableFlag &&
3140 (srcDim != 1 || srcDimScalableFlag)))
3141 foundMismatchingDims =
true;
3143 if (foundMismatchingDims) {
3144 if (mismatchingDims !=
nullptr) {
3145 mismatchingDims->first.dim = srcDim;
3146 mismatchingDims->first.isScalable = srcDimScalableFlag;
3148 mismatchingDims->second.dim = dstDim;
3149 mismatchingDims->second.isScalable = dstDimScalableFlag;
3158LogicalResult BroadcastOp::verify() {
3159 std::pair<VectorDim, VectorDim> mismatchingDims;
3161 getSourceType(), getResultVectorType(), &mismatchingDims);
3165 return emitOpError(
"source rank higher than destination rank");
3168 << (mismatchingDims.first.isScalable ?
"[" :
"")
3169 << mismatchingDims.first.dim
3170 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
3171 << (mismatchingDims.second.isScalable ?
"[" :
"")
3172 << mismatchingDims.second.dim
3173 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
3176 return emitOpError(
"source type is not a vector");
3177 llvm_unreachable(
"unexpected vector.broadcast op error");
3184 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3188 VectorType srcType = srcShapeCast.getSourceVectorType();
3189 VectorType destType = broadcastOp.getResultVectorType();
3197 srcShapeCast.getResultVectorType().getShape();
3200 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3201 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3202 shapecastShape.take_back(numTrailingDims)))
3205 assert(all_of(srcShape.drop_back(numTrailingDims),
3206 [](
int64_t E) { return E == 1; }) &&
3207 all_of(shapecastShape.drop_back(numTrailingDims),
3208 [](
int64_t E) { return E == 1; }) &&
3209 "ill-formed shape_cast");
3211 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3215OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3216 if (getSourceType() == getResultVectorType())
3221 if (!adaptor.getSource())
3223 auto vectorType = getResultVectorType();
3224 if (
auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3225 if (vectorType.getElementType() != attr.getType())
3229 if (
auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3230 if (vectorType.getElementType() != attr.getType())
3234 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3244struct BroadcastFolder :
public OpRewritePattern<BroadcastOp> {
3247 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3248 PatternRewriter &rewriter)
const override {
3249 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3253 broadcastOp.getResultVectorType(),
3254 srcBroadcast.getSource());
3267struct BroadcastToShapeCast final
3268 :
public OpRewritePattern<vector::BroadcastOp> {
3270 LogicalResult matchAndRewrite(vector::BroadcastOp
broadcast,
3271 PatternRewriter &rewriter)
const override {
3273 auto sourceType = dyn_cast<VectorType>(
broadcast.getSourceType());
3276 broadcast,
"source is a scalar, shape_cast doesn't support scalar");
3280 if (sourceType.getNumElements() != outType.getNumElements()) {
3282 broadcast,
"broadcast to a greater number of elements");
3292void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3293 MLIRContext *context) {
3294 results.
add<BroadcastFolder, BroadcastToShapeCast>(context);
3301LogicalResult ShuffleOp::verify() {
3302 VectorType resultType = getResultVectorType();
3303 VectorType v1Type = getV1VectorType();
3304 VectorType v2Type = getV2VectorType();
3306 int64_t resRank = resultType.getRank();
3307 int64_t v1Rank = v1Type.getRank();
3308 int64_t v2Rank = v2Type.getRank();
3309 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3310 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3311 if (!wellFormed0DCase && !wellFormedNDCase)
3315 for (int64_t r = 1; r < v1Rank; ++r) {
3316 int64_t resDim = resultType.getDimSize(r);
3317 int64_t v1Dim = v1Type.getDimSize(r);
3318 int64_t v2Dim = v2Type.getDimSize(r);
3319 if (resDim != v1Dim || v1Dim != v2Dim)
3323 ArrayRef<int64_t> mask = getMask();
3324 int64_t maskLength = mask.size();
3325 if (maskLength <= 0)
3327 if (maskLength != resultType.getDimSize(0))
3330 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3331 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3332 for (
auto [idx, maskPos] : llvm::enumerate(mask)) {
3334 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
3340ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3341 ShuffleOp::Adaptor adaptor,
3342 SmallVectorImpl<Type> &inferredReturnTypes) {
3343 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3347 auto v1Rank = v1Type.getRank();
3350 SmallVector<int64_t, 4> shape;
3351 shape.reserve(v1Rank);
3352 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3355 llvm::append_range(shape, v1Type.getShape().drop_front());
3356 inferredReturnTypes.push_back(
3357 VectorType::get(shape, v1Type.getElementType()));
3361template <
typename T>
3364 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3365 return value == expected++;
3372 auto v1Type = op.getV1VectorType();
3373 auto v2Type = op.getV2VectorType();
3374 auto mask = op.getMask();
3387 if (!isV1Poison && !isV2Poison)
3390 int64_t v1Size = op.getV1VectorType().getDimSize(0);
3391 bool changed =
false;
3393 for (
int64_t &idx : newMask) {
3394 if (idx == ShuffleOp::kPoisonIndex)
3396 if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
3397 idx = ShuffleOp::kPoisonIndex;
3405 op.setMask(newMask);
3406 return op.getResult();
3415 return ub::PoisonAttr::get(context);
3422 auto v1Type = op.getV1VectorType();
3423 if (v1Type.getRank() != 1)
3435 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3438 v2Elements = to_vector(v2DenseAttr.getValues<
Attribute>());
3439 poisonElement = v2Elements[0];
3442 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3445 v1Elements = to_vector(v1DenseAttr.getValues<
Attribute>());
3446 poisonElement = v1Elements[0];
3451 int64_t v1Size = v1Type.getDimSize(0);
3452 for (
int64_t maskIdx : mask) {
3455 if (maskIdx == ShuffleOp::kPoisonIndex) {
3456 indexedElm = poisonElement;
3458 if (maskIdx < v1Size)
3459 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3461 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3464 results.push_back(indexedElm);
3470OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3471 auto v1Type = getV1VectorType();
3473 assert(!v1Type.isScalable() && !getV2VectorType().isScalable() &&
3474 "Vector shuffle does not support scalable vectors");
3478 if (v1Type.getRank() == 0)
3486 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3487 if (!v1Attr || !v2Attr)
3502struct Canonicalize0DShuffleOp :
public OpRewritePattern<ShuffleOp> {
3505 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3506 PatternRewriter &rewriter)
const override {
3507 VectorType v1VectorType = shuffleOp.getV1VectorType();
3508 ArrayRef<int64_t> mask = shuffleOp.getMask();
3509 if (v1VectorType.getRank() > 0)
3511 if (mask.size() != 1)
3513 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3531static Value getScalarSplatSource(Value value) {
3537 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3544 if (isa<VectorType>(
broadcast.getSourceType()))
3552class ShuffleSplat final :
public OpRewritePattern<ShuffleOp> {
3556 LogicalResult matchAndRewrite(ShuffleOp op,
3557 PatternRewriter &rewriter)
const override {
3558 Value splat = getScalarSplatSource(op.getV1());
3559 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3569class ShuffleInterleave :
public OpRewritePattern<ShuffleOp> {
3573 LogicalResult matchAndRewrite(ShuffleOp op,
3574 PatternRewriter &rewriter)
const override {
3575 VectorType resultType = op.getResultVectorType();
3576 if (resultType.isScalable())
3578 op,
"ShuffleOp can't represent a scalable interleave");
3580 if (resultType.getRank() != 1)
3582 op,
"ShuffleOp can't represent an n-D interleave");
3584 VectorType sourceType = op.getV1VectorType();
3585 if (sourceType != op.getV2VectorType() ||
3586 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3588 op,
"ShuffleOp types don't match an interleave");
3591 ArrayRef<int64_t> shuffleMask = op.getMask();
3592 int64_t resultVectorSize = resultType.getNumElements();
3593 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3594 int64_t maskValueA = shuffleMask[i * 2];
3595 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3596 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3598 "ShuffleOp mask not interleaving");
3614class FoldUnusedShuffleOperand final :
public OpRewritePattern<ShuffleOp> {
3618 LogicalResult matchAndRewrite(ShuffleOp op,
3619 PatternRewriter &rewriter)
const override {
3621 if (llvm::all_of(op.getMask(), [](int64_t mask) {
3622 return mask == ShuffleOp::kPoisonIndex;
3629 auto replaceOperandWithPoison = [&](OpOperand &operand) {
3632 Value poison = ub::PoisonOp::create(rewriter, op.getLoc(),
3641 int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
3642 ? op.getV1VectorType().getDimSize(0)
3644 bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3645 return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
3647 if (!isV1Used && succeeded(replaceOperandWithPoison(op.getV1Mutable())))
3651 bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3652 return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
3654 if (!isV2Used && succeeded(replaceOperandWithPoison(op.getV2Mutable())))
3662void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3663 MLIRContext *context) {
3664 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
3665 FoldUnusedShuffleOperand>(context);
3672void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3674 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3677void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3678 Value source, Value dest) {
3679 auto vectorTy = cast<VectorType>(dest.
getType());
3680 build(builder,
result, source, dest,
3681 SmallVector<int64_t>(vectorTy.getRank(), 0));
3684void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3685 Value source, Value dest, int64_t position) {
3686 build(builder,
result, source, dest, ArrayRef<int64_t>{position});
3689void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3690 Value source, Value dest, OpFoldResult position) {
3691 build(builder,
result, source, dest, ArrayRef<OpFoldResult>{position});
3694void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3695 Value source, Value dest,
3696 ArrayRef<int64_t> position) {
3697 SmallVector<OpFoldResult> posVals;
3698 posVals.reserve(position.size());
3699 llvm::transform(position, std::back_inserter(posVals),
3701 build(builder,
result, source, dest, posVals);
3704void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3705 Value source, Value dest,
3706 ArrayRef<OpFoldResult> position) {
3707 SmallVector<int64_t> staticPos;
3708 SmallVector<Value> dynamicPos;
3710 build(builder,
result, source, dest, dynamicPos,
3714LogicalResult InsertOp::verify() {
3715 if (
auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3716 if (srcTy.getRank() == 0)
3718 "expected a scalar instead of a 0-d vector as the source operand");
3720 SmallVector<OpFoldResult> position = getMixedPosition();
3721 auto destVectorType = getDestVectorType();
3722 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
3724 "expected position attribute of rank no greater than dest vector rank");
3725 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3726 if (srcVectorType &&
3727 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3728 static_cast<unsigned>(destVectorType.getRank())))
3729 return emitOpError(
"expected position attribute rank + source rank to "
3730 "match dest vector rank");
3731 if (!srcVectorType &&
3732 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
3734 "expected position attribute rank to match the dest vector rank");
3735 for (
auto [idx, pos] : llvm::enumerate(position)) {
3736 if (
auto attr = dyn_cast<Attribute>(pos)) {
3737 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3739 destVectorType.getDimSize(idx))) {
3740 return emitOpError(
"expected position attribute #")
3742 <<
" to be a non-negative integer smaller than the "
3744 "dest vector dimension";
3757 assert(positions.size() <= completePositions.size() &&
3758 "positions size must be less than or equal to destTy rank");
3759 copy(positions, completePositions.begin());
3767class InsertToBroadcast final :
public OpRewritePattern<InsertOp> {
3771 LogicalResult matchAndRewrite(InsertOp insertOp,
3772 PatternRewriter &rewriter)
const override {
3774 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3775 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3776 srcVecType.getNumElements())
3779 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3785class InsertSplatToSplat final :
public OpRewritePattern<InsertOp> {
3789 LogicalResult matchAndRewrite(InsertOp op,
3790 PatternRewriter &rewriter)
const override {
3792 Value splat = getScalarSplatSource(op.getValueToStore());
3793 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3821class InsertChainFullyInitialized final :
public OpRewritePattern<InsertOp> {
3824 LogicalResult matchAndRewrite(InsertOp op,
3825 PatternRewriter &rewriter)
const override {
3827 VectorType destTy = op.getDestVectorType();
3828 if (destTy.isScalable())
3831 for (Operation *user : op.getResult().getUsers())
3832 if (
auto insertOp = dyn_cast<InsertOp>(user))
3833 if (insertOp.getDest() == op.getResult())
3836 InsertOp currentOp = op;
3837 SmallVector<InsertOp> chainInsertOps;
3840 if (currentOp.hasDynamicPosition())
3843 chainInsertOps.push_back(currentOp);
3844 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3847 if (currentOp && !currentOp->hasOneUse())
3851 int64_t vectorSize = destTy.getNumElements();
3852 int64_t initializedCount = 0;
3853 SmallVector<bool> initializedDestIdxs(vectorSize,
false);
3854 SmallVector<int64_t> pendingInsertPos;
3855 SmallVector<int64_t> pendingInsertSize;
3856 SmallVector<Value> pendingInsertValues;
3858 for (
auto insertOp : chainInsertOps) {
3860 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3864 int64_t insertBeginPosition =
3869 int64_t insertSize = 1;
3870 if (
auto srcVectorType =
3871 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3872 insertSize = srcVectorType.getNumElements();
3874 assert(insertBeginPosition + insertSize <= vectorSize &&
3875 "insert would overflow the vector");
3877 for (
auto index : llvm::seq<int64_t>(insertBeginPosition,
3878 insertBeginPosition + insertSize)) {
3879 if (initializedDestIdxs[index])
3881 initializedDestIdxs[index] =
true;
3887 pendingInsertPos.push_back(insertBeginPosition);
3888 pendingInsertSize.push_back(insertSize);
3889 pendingInsertValues.push_back(insertOp.getValueToStore());
3891 if (initializedCount == vectorSize)
3896 if (initializedCount != vectorSize)
3899 SmallVector<Value> elements(vectorSize);
3900 for (
auto [insertBeginPosition, insertSize, valueToStore] :
3901 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3902 pendingInsertValues))) {
3903 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3905 if (!srcVectorType) {
3906 elements[insertBeginPosition] = valueToStore;
3910 Repeated<Type> elementToInsertTypes(insertSize,
3911 srcVectorType.getElementType());
3913 auto elementsToInsert = vector::ToElementsOp::create(
3914 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3915 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3916 elements[insertBeginPosition + linearIdx] =
3917 elementsToInsert.getResult(linearIdx);
3931 int64_t maxVectorSizeFoldThreshold) {
3932 if (insertOp.hasDynamicPosition())
3935 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3943 VectorType destTy = insertOp.getDestVectorType();
3944 if (destTy.isScalable())
3948 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3949 !insertOp->hasOneUse())
3954 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3961 Type destEltType = destTy.getElementType();
3965 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3966 for (
auto value : denseSource.getValues<
Attribute>())
3972 auto allValues = llvm::to_vector(denseDst.getValues<
Attribute>());
3973 copy(insertedValues, allValues.begin() + insertBeginPosition);
3982 auto destInsert = insertOp.getDest().
getDefiningOp<InsertOp>();
3986 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3989 insertOp.
setOperand(1, destInsert.getDest());
3990 return insertOp.getResult();
3993void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3994 MLIRContext *context) {
3995 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3996 InsertChainFullyInitialized>(context);
3999OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
4002 constexpr int64_t vectorSizeFoldThreshold = 256;
4006 if (getNumIndices() == 0 && getValueToStoreType() ==
getType())
4007 return getValueToStore();
4011 SmallVector<Value> operands = {getValueToStore(), getDest()};
4017 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
4020 *
this, adaptor.getValueToStore(), adaptor.getDest(),
4021 vectorSizeFoldThreshold)) {
4025 return inplaceFolded;
4032void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4033 Value source, Value dest,
4034 ArrayRef<int64_t> offsets,
4035 ArrayRef<int64_t> strides) {
4036 result.addOperands({source, dest});
4040 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(
result.name),
4042 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(
result.name),
4047template <
typename OpType>
4051 StringRef attrName) {
4052 if (arrayAttr.size() >
shape.size())
4053 return op.emitOpError(
"expected ")
4054 << attrName <<
" attribute of rank no greater than vector rank";
4061template <
typename OpType>
4065 bool halfOpen =
true) {
4066 for (
auto attr : arrayAttr) {
4067 auto val = llvm::cast<IntegerAttr>(attr).getInt();
4071 if (val < min || val >= upper)
4072 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
4073 <<
min <<
", " << upper <<
")";
4081template <
typename OpType>
4086 for (
auto [
index, attrDimPair] :
4087 llvm::enumerate(llvm::zip_first(arrayAttr,
shape))) {
4088 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
4092 if (val < min || val >=
max)
4093 return op.emitOpError(
"expected ")
4094 << attrName <<
" dimension " <<
index <<
" to be confined to ["
4095 <<
min <<
", " <<
max <<
")";
4105template <
typename OpType>
4110 assert(arrayAttr1.size() <=
shape.size());
4111 assert(arrayAttr2.size() <=
shape.size());
4112 for (
auto [
index, it] :
4113 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2,
shape))) {
4114 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
4115 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
4119 if (val1 + val2 < 0 || val1 + val2 >=
max)
4120 return op.emitOpError(
"expected sum(")
4121 << attrName1 <<
", " << attrName2 <<
") dimension " <<
index
4122 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
4130 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
4132 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
4135LogicalResult InsertStridedSliceOp::verify() {
4136 auto sourceVectorType = getSourceVectorType();
4137 auto destVectorType = getDestVectorType();
4138 auto offsets = getOffsetsAttr();
4139 auto strides = getStridesAttr();
4140 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
4142 "expected offsets of same size as destination vector rank");
4143 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
4144 return emitOpError(
"expected strides of same size as source vector rank");
4145 if (sourceVectorType.getRank() > destVectorType.getRank())
4147 "expected source rank to be no greater than destination rank");
4149 auto sourceShape = sourceVectorType.getShape();
4150 auto destShape = destVectorType.getShape();
4151 SmallVector<int64_t, 4> sourceShapeAsDestShape(
4152 destShape.size() - sourceShape.size(), 0);
4153 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
4154 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
4155 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
4164 offName,
"source vector shape",
4168 unsigned rankDiff = destShape.size() - sourceShape.size();
4169 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
4170 if (sourceVectorType.getScalableDims()[idx] !=
4171 destVectorType.getScalableDims()[idx + rankDiff]) {
4172 return emitOpError(
"mismatching scalable flags (at source vector idx=")
4175 if (sourceVectorType.getScalableDims()[idx]) {
4176 auto sourceSize = sourceShape[idx];
4177 auto destSize = destShape[idx + rankDiff];
4178 if (sourceSize != destSize) {
4181 << (
" to match the corresponding base size from the input "
4183 << sourceSize << (
" vs ") << destSize << (
")");
4193class FoldInsertStridedSliceSplat final
4194 :
public OpRewritePattern<InsertStridedSliceOp> {
4198 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4199 PatternRewriter &rewriter)
const override {
4201 auto dst = insertStridedSliceOp.getDest();
4202 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4203 if (!splat || getScalarSplatSource(dst) != splat)
4206 rewriter.
replaceOp(insertStridedSliceOp, dst);
4213class FoldInsertStridedSliceOfExtract final
4214 :
public OpRewritePattern<InsertStridedSliceOp> {
4218 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4219 PatternRewriter &rewriter)
const override {
4220 auto extractStridedSliceOp =
4221 insertStridedSliceOp.getValueToStore()
4222 .getDefiningOp<vector::ExtractStridedSliceOp>();
4224 if (!extractStridedSliceOp)
4227 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4231 if (extractStridedSliceOp.getStrides() !=
4232 insertStridedSliceOp.getStrides() ||
4233 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4236 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4243class InsertStridedSliceConstantFolder final
4244 :
public OpRewritePattern<InsertStridedSliceOp> {
4250 static constexpr int64_t vectorSizeFoldThreshold = 256;
4252 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4253 PatternRewriter &rewriter)
const override {
4257 Attribute vectorDestCst;
4261 VectorType destTy = destVector.getType();
4262 if (destTy.isScalable())
4266 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4267 !destVector.hasOneUse())
4271 Attribute sourceCst;
4281 if (op.hasNonUnitStrides())
4284 VectorType sliceVecTy = sourceValue.getType();
4285 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4286 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4287 SmallVector<int64_t, 4> offsets =
getI64SubArray(op.getOffsets());
4288 SmallVector<int64_t, 4> destStrides =
computeStrides(destTy.getShape());
4296 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4297 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4298 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4299 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4300 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4301 MutableArrayRef<int64_t> currSlicePosition(
4302 currDestPosition.begin() + rankDifference, currDestPosition.end());
4303 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4306 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
4307 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
4308 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4309 "Invalid slice element");
4310 newValues[linearizedPosition] = *sliceValuesIt;
4323void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4324 RewritePatternSet &results, MLIRContext *context) {
4325 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4326 InsertStridedSliceConstantFolder>(context);
4329OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4330 if (getSourceVectorType() == getDestVectorType())
4331 return getValueToStore();
4340void OuterProductOp::build(OpBuilder &builder, OperationState &
result,
4341 Value
lhs, Value
rhs, Value acc) {
4346void OuterProductOp::print(OpAsmPrinter &p) {
4347 p <<
" " << getLhs() <<
", " << getRhs();
4349 p <<
", " << getAcc();
4352 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
4355ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &
result) {
4356 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4363 if (operandsInfo.size() < 2)
4365 "expected at least 2 operands");
4366 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4367 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4370 "expected vector type for operand #1");
4374 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4375 vRHS.getScalableDims()[0]};
4376 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4377 vLHS.getElementType(), scalableDimsRes);
4380 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4381 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4385 if (!
result.attributes.get(OuterProductOp::getKindAttrName(
result.name))) {
4386 result.attributes.append(
4387 OuterProductOp::getKindAttrName(
result.name),
4388 CombiningKindAttr::get(
result.getContext(),
4389 OuterProductOp::getDefaultKind()));
4395 (operandsInfo.size() > 2 &&
4400LogicalResult OuterProductOp::verify() {
4401 Type tRHS = getOperandTypeRHS();
4402 VectorType vLHS = getOperandVectorTypeLHS(),
4403 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4404 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4406 if (vLHS.getRank() != 1)
4407 return emitOpError(
"expected 1-d vector for operand #1");
4411 if (vRHS.getRank() != 1)
4412 return emitOpError(
"expected 1-d vector for operand #2");
4413 if (vRES.getRank() != 2)
4415 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4416 return emitOpError(
"expected #1 operand dim to match result dim #1");
4417 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4418 return emitOpError(
"expected #2 operand dim to match result dim #2");
4419 if (vLHS.isScalable() && !vRHS.isScalable()) {
4423 "expected either both or only #2 operand dim to be scalable");
4427 if (vRES.getRank() != 1)
4429 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4430 return emitOpError(
"expected #1 operand dim to match result dim #1");
4433 if (vACC && vACC != vRES)
4434 return emitOpError(
"expected operand #3 of same type as result type");
4436 if (!getKindAttr()) {
4437 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
4438 "'vector.kind<add>')");
4443 return emitOpError(
"unsupported outerproduct type");
4452Type OuterProductOp::getExpectedMaskType() {
4453 auto vecType = this->getResultVectorType();
4454 return VectorType::get(vecType.getShape(),
4455 IntegerType::get(vecType.getContext(), 1),
4456 vecType.getScalableDims());
4470 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4472 shape.reserve(vectorType.getRank());
4474 for (
unsigned e = offsets.size(); idx < e; ++idx)
4475 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4476 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4477 shape.push_back(vectorType.getShape()[idx]);
4479 return VectorType::get(
shape, vectorType.getElementType(),
4480 vectorType.getScalableDims());
4483void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4484 Value source, ArrayRef<int64_t> offsets,
4485 ArrayRef<int64_t> sizes,
4486 ArrayRef<int64_t> strides) {
4487 result.addOperands(source);
4493 offsetsAttr, sizesAttr, stridesAttr));
4494 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(
result.name),
4496 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(
result.name),
4498 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(
result.name),
4502LogicalResult ExtractStridedSliceOp::verify() {
4503 auto type = getSourceVectorType();
4504 auto offsets = getOffsetsAttr();
4505 auto sizes = getSizesAttr();
4506 auto strides = getStridesAttr();
4507 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4509 "expected offsets, sizes and strides attributes of same size");
4511 auto shape = type.getShape();
4512 auto offName = getOffsetsAttrName();
4513 auto sizesName = getSizesAttrName();
4514 auto stridesName = getStridesAttrName();
4530 shape, offName, sizesName,
4535 offsets, sizes, strides);
4536 if (getResult().
getType() != resultType)
4537 return emitOpError(
"expected result type to be ") << resultType;
4539 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
4540 if (type.getScalableDims()[idx]) {
4541 auto inputDim = type.getShape()[idx];
4542 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4543 if (inputDim != inputSize)
4546 << (
" to match the corresponding base size from the input "
4548 << inputSize << (
" vs ") << inputDim << (
")");
4561 auto getElement = [](
ArrayAttr array,
int idx) {
4562 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4564 ArrayAttr extractOffsets = op.getOffsets();
4567 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4569 if (op.getSourceVectorType().getRank() !=
4570 insertOp.getSourceVectorType().getRank())
4572 ArrayAttr insertOffsets = insertOp.getOffsets();
4573 ArrayAttr insertStrides = insertOp.getStrides();
4576 if (extractOffsets.size() > insertOffsets.size())
4578 bool patialoverlap =
false;
4579 bool disjoint =
false;
4581 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4582 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
4584 int64_t start = getElement(insertOffsets, dim);
4585 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4586 int64_t offset = getElement(extractOffsets, dim);
4587 int64_t size = getElement(extractSizes, dim);
4589 if (start <= offset && offset < end) {
4592 if (offset + size > end)
4593 patialoverlap =
true;
4594 offsetDiffs.push_back(offset - start);
4601 if (!disjoint && !patialoverlap) {
4602 op.setOperand(insertOp.getValueToStore());
4605 op.setOffsetsAttr(
b.getI64ArrayAttr(offsetDiffs));
4611 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4626 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4631 if (op.hasNonUnitStrides())
4634 VectorType sourceVecTy = op.getSourceVectorType();
4638 VectorType sliceVecTy = op.getType();
4640 int64_t rank = sliceVecTy.getRank();
4652 const auto denseValuesBegin = dense.value_begin<
Attribute>();
4654 sliceValues.reserve(sliceVecTy.getNumElements());
4658 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4660 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4661 }
while (succeeded(
incSlicePosition(currSlicePosition, sliceShape, offsets)));
4663 assert(
static_cast<int64_t>(sliceValues.size()) ==
4664 sliceVecTy.getNumElements() &&
4665 "Invalid number of slice elements");
4669OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4670 if (getSourceVectorType() == getResult().
getType())
4677 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4684void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4706class StridedSliceFolder final
4707 :
public OpRewritePattern<ExtractStridedSliceOp> {
4709 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4711 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4712 PatternRewriter &rewriter)
const override {
4713 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4717 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4720 SmallVector<int64_t> firstOffsets =
getI64SubArray(firstOp.getOffsets());
4721 SmallVector<int64_t> firstSizes =
getI64SubArray(firstOp.getSizes());
4722 SmallVector<int64_t> secondOffsets =
getI64SubArray(secondOp.getOffsets());
4723 SmallVector<int64_t> secondSizes =
getI64SubArray(secondOp.getSizes());
4725 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4726 SmallVector<int64_t> combinedOffsets(newRank, 0);
4727 SmallVector<int64_t> combinedSizes(newRank);
4728 ArrayRef<int64_t> firstSourceShape =
4729 firstOp.getSourceVectorType().getShape();
4730 for (
unsigned i = 0; i < newRank; ++i) {
4731 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4732 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4733 combinedOffsets[i] = off1 + off2;
4735 if (i < secondSizes.size()) {
4736 combinedSizes[i] = secondSizes[i];
4737 }
else if (i < firstSizes.size()) {
4738 combinedSizes[i] = firstSizes[i];
4740 combinedSizes[i] = firstSourceShape[i];
4744 SmallVector<int64_t> combinedStrides(newRank, 1);
4746 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4764class StridedSliceCreateMaskFolder final
4765 :
public OpRewritePattern<ExtractStridedSliceOp> {
4769 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4770 PatternRewriter &rewriter)
const override {
4771 Location loc = extractStridedSliceOp.getLoc();
4775 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4779 if (extractStridedSliceOp.hasNonUnitStrides())
4782 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4784 SmallVector<int64_t> sliceOffsets;
4787 SmallVector<int64_t> sliceSizes;
4791 SmallVector<Value> sliceMaskDimSizes;
4792 sliceMaskDimSizes.reserve(maskDimSizes.size());
4796 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4797 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4801 IntegerAttr offsetAttr =
4803 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4804 Value sliceMaskDimSize =
4805 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4806 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4811 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4815 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4823class StridedSliceConstantMaskFolder final
4824 :
public OpRewritePattern<ExtractStridedSliceOp> {
4828 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4829 PatternRewriter &rewriter)
const override {
4832 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4833 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4834 if (!constantMaskOp)
4837 if (extractStridedSliceOp.hasNonUnitStrides())
4840 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4842 SmallVector<int64_t> sliceOffsets;
4845 SmallVector<int64_t> sliceSizes;
4849 SmallVector<int64_t> sliceMaskDimSizes;
4850 sliceMaskDimSizes.reserve(maskDimSizes.size());
4851 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4852 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4853 int64_t sliceMaskDimSize = std::max(
4854 static_cast<int64_t
>(0),
4855 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4856 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4859 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4860 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4861 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4864 if (llvm::is_contained(sliceMaskDimSizes, 0))
4865 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4870 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4878class StridedSliceBroadcast final
4879 :
public OpRewritePattern<ExtractStridedSliceOp> {
4883 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4884 PatternRewriter &rewriter)
const override {
4890 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4891 auto dstVecType = llvm::cast<VectorType>(op.getType());
4892 unsigned dstRank = dstVecType.getRank();
4893 unsigned rankDiff = dstRank - srcRank;
4897 bool needsSlice =
false;
4898 for (
unsigned i = 0; i < srcRank; i++) {
4899 if (srcVecType.getDimSize(i) != 1 &&
4900 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4907 SmallVector<int64_t> offsets =
4909 SmallVector<int64_t> sizes =
4911 for (
unsigned i = 0; i < srcRank; i++) {
4912 if (srcVecType.getDimSize(i) == 1) {
4920 source = ExtractStridedSliceOp::create(
4921 rewriter, op->getLoc(), source, offsets, sizes,
4930class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
4934 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4935 PatternRewriter &rewriter)
const override {
4937 Value splat = getScalarSplatSource(op.getSource());
4961class ContiguousExtractStridedSliceToExtract final
4962 :
public OpRewritePattern<ExtractStridedSliceOp> {
4966 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4967 PatternRewriter &rewriter)
const override {
4968 if (op.hasNonUnitStrides())
4970 Value source = op.getOperand();
4971 auto sourceType = cast<VectorType>(source.
getType());
4972 if (sourceType.isScalable() || sourceType.getRank() == 0)
4981 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4982 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4989 if (numOffsets == 0)
4994 if (numOffsets == sourceType.getRank() &&
4995 static_cast<int>(sizes.size()) == sourceType.getRank())
4999 for (
int i = 0; i < numOffsets; ++i) {
5007 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
5008 sizes[numOffsets] == 1) {
5013 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
5014 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
5023void ExtractStridedSliceOp::getCanonicalizationPatterns(
5024 RewritePatternSet &results, MLIRContext *context) {
5027 results.
add<StridedSliceFolder, StridedSliceCreateMaskFolder,
5028 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
5029 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
5039void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5040 VectorType vectorType, Value source,
5042 AffineMapAttr permutationMapAttr,
5045 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5047 padding = ub::PoisonOp::create(builder,
result.location, elemType);
5050 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
5051 *padding, Value(), inBoundsAttr);
5059void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5060 VectorType vectorType, Value source,
5062 AffineMap permutationMap,
5063 std::optional<ArrayRef<bool>> inBounds) {
5064 if (!permutationMap)
5066 llvm::cast<ShapedType>(source.
getType()), vectorType);
5067 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5068 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5071 SmallVector<bool>(vectorType.getRank(),
false));
5073 build(builder,
result, vectorType, source,
indices, padding,
5074 permutationMapAttr, inBoundsAttr);
5080void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5081 VectorType vectorType, Value source,
5083 std::optional<ArrayRef<bool>> inBounds) {
5085 build(builder,
result, vectorType, source,
indices, padding,
5086 AffineMap(), inBounds);
5089template <
typename EmitFun>
5093 for (
auto expr : permutationMap.
getResults()) {
5094 auto dim = dyn_cast<AffineDimExpr>(expr);
5095 auto zero = dyn_cast<AffineConstantExpr>(expr);
5097 if (zero.getValue() != 0) {
5099 "requires a projected permutation_map (at most one dim or the zero "
5100 "constant can appear in each result)");
5105 return emitOpError(
"requires a projected permutation_map (at most one "
5106 "dim or the zero constant can appear in each result)");
5108 if (seen[dim.getPosition()]) {
5110 "requires a permutation_map that is a permutation (found one dim "
5111 "used more than once)");
5113 seen[dim.getPosition()] =
true;
5120 VectorType vectorType, VectorType maskType,
5121 VectorType inferredMaskType,
AffineMap permutationMap,
5123 if (op->hasAttr(
"masked")) {
5124 return op->emitOpError(
"masked attribute has been removed. "
5125 "Use in_bounds instead.");
5128 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
5129 return op->emitOpError(
5130 "requires source to be a memref or ranked tensor type");
5132 auto elementType = shapedType.getElementType();
5134 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
5136 unsigned sourceVecSize =
5138 vectorElementType.getShape().back();
5139 unsigned resultVecSize =
5141 vectorType.getShape().back();
5142 if (resultVecSize % sourceVecSize != 0)
5143 return op->emitOpError(
5144 "requires the bitwidth of the minor 1-D vector to be an integral "
5145 "multiple of the bitwidth of the minor 1-D vector of the source");
5147 unsigned sourceVecEltRank = vectorElementType.getRank();
5148 unsigned resultVecRank = vectorType.getRank();
5149 if (sourceVecEltRank > resultVecRank)
5150 return op->emitOpError(
5151 "requires source vector element and vector result ranks to match.");
5152 unsigned rankOffset = resultVecRank - sourceVecEltRank;
5155 return op->emitOpError(
"requires a permutation_map with result dims of "
5156 "the same rank as the vector type");
5159 return op->emitOpError(
"does not support masks with vector element type");
5162 unsigned minorSize =
5163 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
5164 unsigned resultVecSize =
5167 return op->emitOpError(
5168 "requires the bitwidth of the minor 1-D vector to be an integral "
5169 "multiple of the bitwidth of the source element type");
5173 return op->emitOpError(
"requires a permutation_map with result dims of "
5174 "the same rank as the vector type");
5178 return op->emitOpError(
"requires permutation_map without symbols");
5180 if (permutationMap.
getNumInputs() != shapedType.getRank())
5181 return op->emitOpError(
"requires a permutation_map with input dims of the "
5182 "same rank as the source type");
5184 if (maskType && maskType != inferredMaskType)
5185 return op->emitOpError(
"inferred mask type (")
5186 << inferredMaskType <<
") and mask operand type (" << maskType
5190 return op->emitOpError(
"expects the in_bounds attr of same rank "
5191 "as permutation_map results: ")
5192 << AffineMapAttr::get(permutationMap)
5193 <<
" vs inBounds of size: " << inBounds.size();
5200 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5201 if (op.getPermutationMap().isMinorIdentity())
5202 elidedAttrs.push_back(op.getPermutationMapAttrName());
5204 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
5205 elidedAttrs.push_back(op.getInBoundsAttrName());
5209void TransferReadOp::print(OpAsmPrinter &p) {
5212 p <<
", " << getMask();
5219 auto i1Type = IntegerType::get(permMap.
getContext(), 1);
5221 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
5226 if (maskShape.empty())
5227 maskShape.push_back(1);
5232 return VectorType::get(maskShape, i1Type, scalableDims);
5249 if (hasMask.succeeded()) {
5256 if (types.size() != 2)
5257 return parser.
emitError(typesLoc,
"requires two types");
5259 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5260 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5261 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5262 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5264 return parser.
emitError(typesLoc,
"requires vector type");
5265 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(
result.name);
5269 if (shapedType.getRank() <
5272 "expected a custom permutation_map when "
5273 "rank(source) != rank(destination)");
5275 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5277 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5279 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(
result.name);
5280 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5281 if (!inBoundsAttr) {
5282 result.addAttribute(inBoundsAttrName,
5291 if (hasMask.succeeded()) {
5292 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5294 maskInfo.
location,
"does not support masks with vector element type");
5297 "expected the same rank for the vector and the "
5298 "results of the permutation map");
5306 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5308 {1, static_cast<int32_t>(indexInfo.size()), 1,
5309 static_cast<int32_t>(hasMask.succeeded())}));
5313LogicalResult TransferReadOp::verify() {
5315 ShapedType shapedType = getShapedType();
5317 VectorType maskType = getMaskType();
5318 auto paddingType = getPadding().getType();
5319 auto permutationMap = getPermutationMap();
5320 VectorType inferredMaskType =
5323 auto sourceElementType = shapedType.getElementType();
5325 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
5326 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5329 shapedType, vectorType, maskType,
5330 inferredMaskType, permutationMap, getInBounds())))
5333 if (
auto sourceVectorElementType =
5334 llvm::dyn_cast<VectorType>(sourceElementType)) {
5337 if (sourceVectorElementType != paddingType)
5339 "requires source element type and padding type to match.");
5343 if (!VectorType::isValidElementType(paddingType))
5344 return emitOpError(
"requires valid padding vector elemental type");
5347 if (paddingType != sourceElementType)
5349 "requires formal padding and source of the same elemental type");
5360Type TransferReadOp::getExpectedMaskType() {
5367VectorType TransferReadOp::getVectorType() {
5368 return cast<VectorType>(getVector().
getType());
5371template <
typename TransferOp>
5375 if (op.getShapedType().isDynamicDim(indicesIdx))
5379 if (!cstOp.has_value())
5382 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5383 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5385 return cstOp.value() + vectorSize <= sourceSize;
5388template <
typename TransferOp>
5392 if (op.getTransferRank() == 0)
5395 bool changed =
false;
5397 newInBounds.reserve(op.getTransferRank());
5402 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
5404 if (op.isDimInBounds(i)) {
5405 newInBounds.push_back(
true);
5410 bool inBounds =
false;
5411 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
5414 dimExpr.getPosition());
5415 nonBcastDims.push_back(i);
5418 newInBounds.push_back(inBounds);
5420 changed |= inBounds;
5426 bool allNonBcastDimsInBounds = llvm::all_of(
5427 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
5428 if (allNonBcastDimsInBounds) {
5430 changed |= !newInBounds[idx];
5431 newInBounds[idx] =
true;
5439 op.setInBoundsAttr(
b.getBoolArrayAttr(newInBounds));
5443template <
typename TransferOp>
5445 auto mask = op.getMask();
5452 op.getMaskMutable().clear();
5460template <
typename TransferOp>
5462 VectorType vecType = op.getVectorType();
5463 if (vecType.getRank() != 1 || vecType.getShape()[0] != 1 ||
5464 vecType.isScalable())
5471 int64_t srcRank = op.getShapedType().getRank();
5477 op.setPermutationMapAttr(AffineMapAttr::get(minorIdentity));
5491static Value foldRAW(TransferReadOp readOp) {
5492 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5494 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5497 return defWrite.getVector();
5499 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5500 cast<VectorTransferOpInterface>(readOp.getOperation())))
5502 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5507OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5508 if (Value vec = foldRAW(*
this))
5521 return OpFoldResult();
5524std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5528void TransferReadOp::getEffects(
5529 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5531 if (llvm::isa<MemRefType>(getShapedType()))
5532 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5533 SideEffects::DefaultResource::get());
5537 if (hasPureTensorSemantics())
5544static AffineMap inverseWithUnusedDims(AffineMap map) {
5546 "expected a projected permutation map");
5551 int64_t pos = cast<AffineDimExpr>(
result).getPosition();
5581struct TransferReadAfterWriteToBroadcast
5582 :
public OpRewritePattern<TransferReadOp> {
5585 LogicalResult matchAndRewrite(TransferReadOp readOp,
5586 PatternRewriter &rewriter)
const override {
5587 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5591 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5595 if (readOp.getMask() || defWrite.getMask())
5598 if (readOp.getIndices() != defWrite.getIndices())
5601 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5605 if (readOp.getTransferChunkAccessed() !=
5606 defWrite.getTransferChunkAccessed())
5613 AffineMap readMap = readOp.getPermutationMap();
5614 AffineMap writeMap = defWrite.getPermutationMap();
5615 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5616 AffineMap composedMap = readMap.
compose(invWriteMap);
5630 int64_t numBroadcastedDims = broadcastedDims.size();
5631 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5633 for (
auto [idx, expr] : llvm::enumerate(composedMap.
getResults())) {
5634 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
5635 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5636 invPerm[effectiveDim] = idx;
5641 VectorType readVecTy = readOp.getVectorType();
5643 auto broadcastedVecTy =
5645 readVecTy.getElementType(),
5648 Value vec = defWrite.getVector();
5649 Location loc = readOp.getLoc();
5650 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5657void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5658 MLIRContext *context) {
5659 results.
add<TransferReadAfterWriteToBroadcast>(context);
5662FailureOr<std::optional<SmallVector<Value>>>
5663TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5664 if (!hasPureBufferSemantics())
5675void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5677 AffineMapAttr permutationMapAttr,
5680 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
5681 build(builder,
result, resultType, vector, dest,
indices, permutationMapAttr,
5682 mask, inBoundsAttr);
5686void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5688 AffineMapAttr permutationMapAttr,
5690 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5691 Value(), inBoundsAttr);
5696void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5698 AffineMap permutationMap,
5699 std::optional<ArrayRef<bool>> inBounds) {
5700 if (!permutationMap)
5703 llvm::cast<VectorType>(vector.
getType()));
5704 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5706 (inBounds && !inBounds.value().empty())
5709 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
5710 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5711 Value(), inBoundsAttr);
5716void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5718 std::optional<ArrayRef<bool>> inBounds) {
5723ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5724 OperationState &
result) {
5727 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5728 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5729 SmallVector<Type, 2> types;
5730 OpAsmParser::UnresolvedOperand maskInfo;
5736 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
5741 if (types.size() != 2)
5742 return parser.
emitError(typesLoc,
"requires two types");
5744 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5746 return parser.
emitError(typesLoc,
"requires vector type");
5747 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5748 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5749 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5750 auto permMapAttrName =
5751 TransferWriteOp::getPermutationMapAttrName(
result.name);
5752 auto permMapAttr =
result.attributes.get(permMapAttrName);
5755 if (shapedType.getRank() <
5758 "expected a custom permutation_map when "
5759 "rank(source) != rank(destination)");
5761 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5763 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5765 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(
result.name);
5766 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5767 if (!inBoundsAttr) {
5768 result.addAttribute(inBoundsAttrName,
5776 if (hasMask.succeeded()) {
5777 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5779 maskInfo.
location,
"does not support masks with vector element type");
5782 "expected the same rank for the vector and the "
5783 "results of the permutation map");
5789 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5791 {1, 1, static_cast<int32_t>(indexInfo.size()),
5792 static_cast<int32_t>(hasMask.succeeded())}));
5793 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5797void TransferWriteOp::print(OpAsmPrinter &p) {
5800 p <<
", " << getMask();
5805LogicalResult TransferWriteOp::verify() {
5807 ShapedType shapedType = getShapedType();
5809 VectorType maskType = getMaskType();
5810 auto permutationMap = getPermutationMap();
5811 VectorType inferredMaskType =
5815 if (llvm::size(
getIndices()) != shapedType.getRank())
5816 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5820 if (hasBroadcastDim())
5821 return emitOpError(
"should not have broadcast dimensions");
5824 shapedType, vectorType, maskType,
5825 inferredMaskType, permutationMap, getInBounds())))
5838Type TransferWriteOp::getExpectedMaskType() {
5845Value TransferWriteOp::getVector() {
return getOperand(0); }
5846VectorType TransferWriteOp::getVectorType() {
5847 return cast<VectorType>(getValueToStore().
getType());
5870static LogicalResult foldReadInitWrite(TransferWriteOp write,
5871 ArrayRef<Attribute>,
5872 SmallVectorImpl<OpFoldResult> &results) {
5874 if (write.getTransferRank() == 0)
5876 auto rankedTensorType =
5877 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5879 if (!rankedTensorType)
5882 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5886 if (read.getTransferRank() == 0)
5889 if (!read.getPermutationMap().isMinorIdentity() ||
5890 !write.getPermutationMap().isMinorIdentity())
5893 if (read.getTransferRank() != write.getTransferRank())
5896 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5899 if (read.getMask() || write.getMask())
5902 if (read.getBase().getType() != rankedTensorType)
5905 if (read.getVectorType() != write.getVectorType())
5908 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5911 auto isNotConstantZero = [](Value v) {
5913 return !cstOp.has_value() || cstOp.value() != 0;
5915 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5916 llvm::any_of(write.getIndices(), isNotConstantZero))
5919 results.push_back(read.getBase());
5923static bool checkSameValueWAR(vector::TransferReadOp read,
5924 vector::TransferWriteOp write) {
5925 return read.getBase() == write.getBase() &&
5926 read.getIndices() == write.getIndices() &&
5927 read.getPermutationMap() == write.getPermutationMap() &&
5928 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5945static LogicalResult foldWAR(TransferWriteOp write,
5946 SmallVectorImpl<OpFoldResult> &results) {
5947 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5949 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5953 if (!checkSameValueWAR(read, write))
5955 results.push_back(read.getBase());
5959LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5960 SmallVectorImpl<OpFoldResult> &results) {
5961 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
5963 if (succeeded(foldWAR(*
this, results)))
5977std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5981void TransferWriteOp::getEffects(
5982 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5984 if (llvm::isa<MemRefType>(getShapedType()))
5985 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5986 SideEffects::DefaultResource::get());
5990 if (hasPureTensorSemantics())
6020class FoldWaw final :
public OpRewritePattern<TransferWriteOp> {
6023 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
6024 PatternRewriter &rewriter)
const override {
6025 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
6027 vector::TransferWriteOp writeToModify = writeOp;
6029 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
6033 writeToModify.getBaseMutable().assign(defWrite.getBase());
6038 cast<VectorTransferOpInterface>(defWrite.getOperation()),
6039 cast<VectorTransferOpInterface>(writeOp.getOperation())))
6043 if (!defWrite->hasOneUse())
6045 writeToModify = defWrite;
6046 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
6075struct SwapExtractSliceOfTransferWrite
6076 :
public OpRewritePattern<tensor::InsertSliceOp> {
6080 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
6081 PatternRewriter &rewriter)
const override {
6082 if (!insertOp.hasUnitStride())
6085 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
6086 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
6088 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
6089 if (!transferOp || !transferOp->hasOneUse())
6094 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
6096 "use-def chain is rank-reducing");
6100 if (!extractOp.hasZeroOffset()) {
6102 "ExtractSliceOp has non-zero offset");
6106 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
6107 return getConstantIntValue(value) == static_cast<int64_t>(0);
6110 "TranferWriteOp has non-zero offset");
6114 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
6116 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
6119 for (
auto [insertSize, extractSize] :
6120 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
6123 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
6128 assert(transferOp.getVectorType().hasStaticShape() &&
6129 "expected vector to have a static shape");
6130 ArrayRef<int64_t>
vectorShape = transferOp.getVectorType().getShape();
6132 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
6133 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
6135 insertOp,
"TransferWriteOp may not write the full tensor.");
6140 SmallVector<bool> newInBounds(
vectorShape.size(),
false);
6141 auto newExtractOp = tensor::ExtractSliceOp::create(
6142 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
6143 insertOp.getDest(), insertOp.getMixedOffsets(),
6144 insertOp.getMixedSizes(), insertOp.getMixedStrides());
6145 auto newTransferWriteOp = TransferWriteOp::create(
6146 rewriter, transferOp.getLoc(), transferOp.getVector(),
6147 newExtractOp.getResult(), transferOp.getIndices(),
6148 transferOp.getPermutationMapAttr(),
6151 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
6159void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
6160 MLIRContext *context) {
6161 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
6164FailureOr<std::optional<SmallVector<Value>>>
6165TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
6166 if (!hasPureBufferSemantics())
6176static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
6178 MemRefType memRefTy) {
6181 if (!vecTy.isScalable() &&
6182 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
6185 if (!memRefTy.isLastDimUnitStride())
6186 return op->
emitOpError(
"most minor memref dim must have unit stride");
6190LogicalResult vector::LoadOp::verify() {
6194 if (
failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
6197 if (memRefTy.getRank() < resVecTy.getRank())
6199 "destination memref has lower rank than the result vector");
6202 Type memElemTy = memRefTy.getElementType();
6203 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6204 if (memVecTy != resVecTy)
6205 return emitOpError(
"base memref and result vector types should match");
6206 memElemTy = memVecTy.getElementType();
6209 if (resVecTy.getElementType() != memElemTy)
6210 return emitOpError(
"base and result element types should match");
6211 if (llvm::size(
getIndices()) != memRefTy.getRank())
6212 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6216OpFoldResult LoadOp::fold(FoldAdaptor) {
6219 return OpFoldResult();
6222std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6226FailureOr<std::optional<SmallVector<Value>>>
6227LoadOp::bubbleDownCasts(OpBuilder &builder) {
6236LogicalResult vector::StoreOp::verify() {
6240 if (
failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
6243 if (memRefTy.getRank() < valueVecTy.getRank())
6244 return emitOpError(
"source memref has lower rank than the vector to store");
6247 Type memElemTy = memRefTy.getElementType();
6248 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6249 if (memVecTy != valueVecTy)
6251 "base memref and valueToStore vector types should match");
6252 memElemTy = memVecTy.getElementType();
6255 if (valueVecTy.getElementType() != memElemTy)
6256 return emitOpError(
"base and valueToStore element type should match");
6257 if (llvm::size(
getIndices()) != memRefTy.getRank())
6258 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6262LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6263 SmallVectorImpl<OpFoldResult> &results) {
6267std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6271FailureOr<std::optional<SmallVector<Value>>>
6272StoreOp::bubbleDownCasts(OpBuilder &builder) {
6281LogicalResult MaskedLoadOp::verify() {
6282 VectorType maskVType = getMaskVectorType();
6283 VectorType passVType = getPassThruVectorType();
6290 if (llvm::size(
getIndices()) != memType.getRank())
6291 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6292 if (resVType.getShape() != maskVType.getShape())
6293 return emitOpError(
"expected result shape to match mask shape");
6294 if (resVType != passVType)
6295 return emitOpError(
"expected pass_thru of same type as result type");
6300class MaskedLoadFolder final :
public OpRewritePattern<MaskedLoadOp> {
6303 LogicalResult matchAndRewrite(MaskedLoadOp
load,
6304 PatternRewriter &rewriter)
const override {
6316 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
6321void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6322 MLIRContext *context) {
6323 results.
add<MaskedLoadFolder>(context);
6326OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6329 return OpFoldResult();
6332FailureOr<std::optional<SmallVector<Value>>>
6333MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6342LogicalResult MaskedStoreOp::verify() {
6343 VectorType maskVType = getMaskVectorType();
6350 if (llvm::size(
getIndices()) != memType.getRank())
6351 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6352 if (valueVType.getShape() != maskVType.getShape())
6353 return emitOpError(
"expected valueToStore shape to match mask shape");
6358class MaskedStoreFolder final :
public OpRewritePattern<MaskedStoreOp> {
6361 LogicalResult matchAndRewrite(MaskedStoreOp store,
6362 PatternRewriter &rewriter)
const override {
6366 store, store.getValueToStore(), store.getBase(), store.getIndices());
6374 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
6379void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6380 MLIRContext *context) {
6381 results.
add<MaskedStoreFolder>(context);
6384LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6385 SmallVectorImpl<OpFoldResult> &results) {
6389FailureOr<std::optional<SmallVector<Value>>>
6390MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6399LogicalResult GatherOp::verify() {
6400 VectorType indVType = getIndexVectorType();
6401 VectorType maskVType = getMaskVectorType();
6403 ShapedType baseType = getBaseType();
6405 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6406 return emitOpError(
"requires base to be a memref or ranked tensor type");
6411 if (llvm::size(getOffsets()) != baseType.getRank())
6412 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6413 if (resVType.getShape() != indVType.getShape())
6414 return emitOpError(
"expected result dim to match indices dim");
6415 if (resVType.getShape() != maskVType.getShape())
6416 return emitOpError(
"expected result dim to match mask dim");
6417 if (resVType != getPassThruVectorType())
6418 return emitOpError(
"expected pass_thru of same type as result type");
6419 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6421 "alignment is only supported for memref bases, not tensor bases");
6430Type GatherOp::getExpectedMaskType() {
6431 auto vecType = this->getIndexVectorType();
6432 return VectorType::get(vecType.getShape(),
6433 IntegerType::get(vecType.getContext(), 1),
6434 vecType.getScalableDims());
6437std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6442static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6443 auto vecType = dyn_cast<VectorType>(indexVec.
getType());
6444 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6450 DenseIntElementsAttr elements;
6455 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6459class GatherFolder final :
public OpRewritePattern<GatherOp> {
6462 LogicalResult matchAndRewrite(GatherOp gather,
6463 PatternRewriter &rewriter)
const override {
6468 rewriter.
replaceOp(gather, gather.getPassThru());
6473 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
6479class FoldContiguousGather final :
public OpRewritePattern<GatherOp> {
6482 LogicalResult matchAndRewrite(GatherOp op,
6483 PatternRewriter &rewriter)
const override {
6484 if (!isa<MemRefType>(op.getBase().getType()))
6487 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6491 op.getOffsets(), op.getMask(),
6498void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6499 MLIRContext *context) {
6500 results.
add<GatherFolder, FoldContiguousGather>(context);
6503FailureOr<std::optional<SmallVector<Value>>>
6504GatherOp::bubbleDownCasts(OpBuilder &builder) {
6513LogicalResult ScatterOp::verify() {
6514 VectorType indVType = getIndexVectorType();
6515 VectorType maskVType = getMaskVectorType();
6517 ShapedType baseType = getBaseType();
6519 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6520 return emitOpError(
"requires base to be a memref or ranked tensor type");
6525 if (llvm::size(getOffsets()) != baseType.getRank())
6526 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6527 if (valueVType.getShape() != indVType.getShape())
6528 return emitOpError(
"expected valueToStore dim to match indices dim");
6529 if (valueVType.getShape() != maskVType.getShape())
6530 return emitOpError(
"expected valueToStore dim to match mask dim");
6531 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6533 "alignment is only supported for memref bases, not tensor bases");
6538class ScatterFolder final :
public OpRewritePattern<ScatterOp> {
6541 LogicalResult matchAndRewrite(ScatterOp scatter,
6542 PatternRewriter &rewriter)
const override {
6543 ShapedType baseType = scatter.getBaseType();
6544 bool isMemRef = isa<MemRefType>(baseType);
6545 if (!isMemRef && !isa<RankedTensorType>(baseType))
6558 rewriter.
replaceOp(scatter, scatter.getBase());
6563 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
6569class FoldContiguousScatter final :
public OpRewritePattern<ScatterOp> {
6572 LogicalResult matchAndRewrite(ScatterOp op,
6573 PatternRewriter &rewriter)
const override {
6576 if (!isa<MemRefType>(op.getBase().getType()))
6579 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6583 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6589void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6590 MLIRContext *context) {
6591 results.
add<ScatterFolder, FoldContiguousScatter>(context);
6594FailureOr<std::optional<SmallVector<Value>>>
6595ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6604LogicalResult ExpandLoadOp::verify() {
6605 VectorType maskVType = getMaskVectorType();
6606 VectorType passVType = getPassThruVectorType();
6613 if (llvm::size(
getIndices()) != memType.getRank())
6614 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6615 if (resVType.getShape() != maskVType.getShape())
6616 return emitOpError(
"expected result shape to match mask shape");
6617 if (resVType != passVType)
6618 return emitOpError(
"expected pass_thru of same type as result type");
6623class ExpandLoadFolder final :
public OpRewritePattern<ExpandLoadOp> {
6626 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6627 PatternRewriter &rewriter)
const override {
6631 expand, expand.getType(), expand.getBase(), expand.getIndices());
6634 rewriter.
replaceOp(expand, expand.getPassThru());
6639 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
6644void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6645 MLIRContext *context) {
6646 results.
add<ExpandLoadFolder>(context);
6649FailureOr<std::optional<SmallVector<Value>>>
6650ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6659LogicalResult CompressStoreOp::verify() {
6660 VectorType maskVType = getMaskVectorType();
6667 if (llvm::size(
getIndices()) != memType.getRank())
6668 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6669 if (valueVType.getShape() != maskVType.getShape())
6670 return emitOpError(
"expected valueToStore shape to match mask shape");
6675class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
6678 LogicalResult matchAndRewrite(CompressStoreOp compress,
6679 PatternRewriter &rewriter)
const override {
6683 compress, compress.getValueToStore(), compress.getBase(),
6684 compress.getIndices());
6692 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
6697void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6698 MLIRContext *context) {
6699 results.
add<CompressStoreFolder>(context);
6702FailureOr<std::optional<SmallVector<Value>>>
6703CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6712void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6714 setResultRanges(getResult(), argRanges.front());
6717std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6718 return llvm::to_vector<4>(getResultVectorType().
getShape());
6721LogicalResult ShapeCastOp::verify() {
6723 VectorType sourceType = getSourceVectorType();
6724 VectorType resultType = getResultVectorType();
6732 int64_t sourceNElms = sourceType.getNumElements();
6733 int64_t resultNElms = resultType.getNumElements();
6734 if (sourceNElms != resultNElms) {
6735 return emitOpError() <<
"has different number of elements at source ("
6736 << sourceNElms <<
") and result (" << resultNElms
6741 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6742 int64_t resultNScalableDims = resultType.getNumScalableDims();
6743 if (sourceNScalableDims != resultNScalableDims)
6744 return emitOpError() <<
"has different number of scalable dims at source ("
6745 << sourceNScalableDims <<
") and result ("
6746 << resultNScalableDims <<
")";
6755static bool isOrderPreserving(TransposeOp transpose) {
6756 ArrayRef<int64_t> permutation = transpose.getPermutation();
6757 VectorType sourceType = transpose.getSourceVectorType();
6758 ArrayRef<int64_t> inShape = sourceType.getShape();
6759 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6760 auto isNonScalableUnitDim = [&](int64_t dim) {
6761 return inShape[dim] == 1 && !inDimIsScalable[dim];
6763 int64_t current = 0;
6764 for (
auto p : permutation) {
6765 if (!isNonScalableUnitDim(p)) {
6775OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6777 VectorType resultType =
getType();
6780 if (getSource().
getType() == resultType)
6784 if (
auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6785 setOperand(precedingShapeCast.getSource());
6790 if (
auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6791 if (isOrderPreserving(transpose)) {
6792 setOperand(transpose.getVector());
6800 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6801 if (bcastOp.getSourceType() == resultType)
6802 return bcastOp.getSource();
6806 if (
auto denseAttr =
6807 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6808 return denseAttr.reshape(
getType());
6824static VectorType trimTrailingOneDims(VectorType oldType) {
6825 ArrayRef<int64_t> oldShape = oldType.getShape();
6826 ArrayRef<int64_t> newShape = oldShape;
6828 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6829 ArrayRef<bool> newScalableDims = oldScalableDims;
6831 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6832 newShape = newShape.drop_back(1);
6833 newScalableDims = newScalableDims.drop_back(1);
6838 if (newShape.empty()) {
6839 newShape = oldShape.take_back();
6840 newScalableDims = oldScalableDims.take_back();
6843 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6858class ShapeCastCreateMaskFolderTrailingOneDim final
6859 :
public OpRewritePattern<ShapeCastOp> {
6863 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6864 PatternRewriter &rewriter)
const override {
6865 Value shapeOpSrc = shapeOp->getOperand(0);
6866 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
6867 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
6868 if (!createMaskOp && !constantMaskOp)
6871 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6872 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6874 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6875 if (newVecType != shapeOpResTy)
6878 auto numDimsToDrop =
6879 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6886 auto maskOperands = createMaskOp.getOperands();
6887 auto numMaskOperands = maskOperands.size();
6890 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6892 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6893 if (!constant || (constant.value() != 1))
6896 SmallVector<Value> newMaskOperands =
6897 maskOperands.drop_back(numDimsToDrop);
6904 if (constantMaskOp) {
6905 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6906 auto numMaskOperands = maskDimSizes.size();
6909 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6911 if (maskDimSizes[i] != 1)
6915 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6928int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6929 ArrayRef<int64_t> dstShape) {
6930 int stretchingFactor = 1;
6931 int numLeadingDims = dstShape.size() - srcShape.size();
6932 for (
int i = 0, e = srcShape.size(); i < e; i++) {
6933 int64_t dstDim = dstShape[numLeadingDims + i];
6934 if (srcShape[i] == 1 && dstDim != 1) {
6935 stretchingFactor *= dstDim;
6938 return stretchingFactor;
6942class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
6946 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6947 PatternRewriter &rewriter)
const override {
6949 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6953 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6954 bool srcIsScalar = !srcVectorType;
6962 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6963 ArrayRef<int64_t> dstShape = dstVectorType.getShape();
6964 ArrayRef<int64_t> srcShape =
6965 srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
6966 ArrayRef<int64_t> broadcastShape =
6967 broadcastOp.getResultVectorType().getShape();
6971 BroadcastableToResult::Success) {
6979 if (srcVectorType.getNumElements() != 1) {
6980 if (getBroadcastStretchingFactor(srcShape, dstShape) !=
6981 getBroadcastStretchingFactor(srcShape, broadcastShape)) {
6988 broadcastOp.getSource());
7007class FoldShapeCastOfFromElements final :
public OpRewritePattern<ShapeCastOp> {
7011 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
7012 PatternRewriter &rewriter)
const override {
7013 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
7018 shapeCastOp, shapeCastOp.getResultVectorType(),
7019 fromElements.getElements());
7026void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
7027 MLIRContext *context) {
7028 results.
add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
7029 FoldShapeCastOfFromElements>(context);
7036LogicalResult BitCastOp::verify() {
7037 auto sourceVectorType = getSourceVectorType();
7038 auto resultVectorType = getResultVectorType();
7040 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
7041 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
7042 return emitOpError(
"dimension size mismatch at: ") << i;
7045 DataLayout dataLayout = DataLayout::closest(*
this);
7046 auto sourceElementBits =
7048 auto resultElementBits =
7051 if (sourceVectorType.getRank() == 0) {
7052 if (sourceElementBits != resultElementBits)
7053 return emitOpError(
"source/result bitwidth of the 0-D vector element "
7054 "types must be equal");
7055 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
7056 resultElementBits * resultVectorType.getShape().back()) {
7058 "source/result bitwidth of the minor 1-D vectors must be equal");
7064OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
7070 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
7071 if (getResult().
getType() == otherOp.getSource().getType())
7072 return otherOp.getSource();
7074 setOperand(otherOp.getSource());
7078 Attribute sourceConstant = adaptor.getSource();
7079 if (!sourceConstant)
7082 Type srcElemType = getSourceVectorType().getElementType();
7083 Type dstElemType = getResultVectorType().getElementType();
7085 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
7086 if (floatPack.isSplat()) {
7087 auto splat = floatPack.getSplatValue<FloatAttr>();
7090 if (srcElemType.
isF16() && dstElemType.
isF32()) {
7091 uint32_t bits =
static_cast<uint32_t
>(
7092 splat.getValue().bitcastToAPInt().getZExtValue());
7094 bits = (bits << 16) | (bits & 0xffff);
7095 APInt intBits(32, bits);
7096 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
7102 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
7103 if (intPack.isSplat()) {
7104 auto splat = intPack.getSplatValue<IntegerAttr>();
7106 if (llvm::isa<IntegerType>(dstElemType) && srcElemType.
isIntOrFloat()) {
7111 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
7112 APInt intBits = splat.getValue().zext(dstBitWidth);
7115 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
7116 intBits = (intBits << srcBitWidth) | intBits;
7126std::optional<SmallVector<int64_t, 4>> BitCastOp::getShapeForUnroll() {
7127 return llvm::to_vector<4>(getResultVectorType().
getShape());
7134static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
7135 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
7136 SmallVector<int64_t, 8> res(memRefType.getShape());
7138 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
7144void TypeCastOp::build(OpBuilder &builder, OperationState &
result,
7146 result.addOperands(source);
7147 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
7148 VectorType vectorType =
7149 VectorType::get(extractShape(memRefType),
7151 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
7152 memRefType.getMemorySpace()));
7155LogicalResult TypeCastOp::verify() {
7156 MemRefType canonicalType =
getMemRefType().canonicalizeStridedLayout();
7157 if (!canonicalType.getLayout().isIdentity())
7158 return emitOpError(
"expects operand to be a memref with identity layout");
7159 if (!getResultMemRefType().getLayout().isIdentity())
7160 return emitOpError(
"expects result to be a memref with identity layout");
7161 if (getResultMemRefType().getMemorySpace() !=
7163 return emitOpError(
"expects result in same memory space");
7166 auto resultType = getResultMemRefType();
7170 "expects result and operand with same underlying scalar type: ")
7172 if (extractShape(sourceType) != extractShape(resultType))
7174 "expects concatenated result and operand shapes to be equal: ")
7183void vector::TransposeOp::build(OpBuilder &builder, OperationState &
result,
7184 Value vector, ArrayRef<int64_t> permutation) {
7185 VectorType vt = llvm::cast<VectorType>(vector.
getType());
7186 SmallVector<int64_t, 4> transposedShape(vt.getRank());
7187 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
7188 for (
unsigned i = 0; i < permutation.size(); ++i) {
7189 transposedShape[i] = vt.getShape()[permutation[i]];
7190 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
7193 result.addOperands(vector);
7194 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
7195 transposedScalableDims));
7196 result.addAttribute(TransposeOp::getPermutationAttrName(
result.name),
7200OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
7203 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
7204 return splat.reshape(getResultVectorType());
7221 if (getSourceVectorType() == getResultVectorType() &&
7222 isOrderPreserving(*
this))
7228LogicalResult vector::TransposeOp::verify() {
7229 VectorType vectorType = getSourceVectorType();
7230 VectorType resultType = getResultVectorType();
7231 int64_t rank = resultType.getRank();
7232 if (vectorType.getRank() != rank)
7233 return emitOpError(
"vector result rank mismatch: ") << rank;
7235 ArrayRef<int64_t> perm = getPermutation();
7236 int64_t size = perm.size();
7238 return emitOpError(
"transposition length mismatch: ") << size;
7239 SmallVector<bool, 8> seen(rank,
false);
7240 for (
const auto &ta : llvm::enumerate(perm)) {
7241 if (ta.value() < 0 || ta.value() >= rank)
7242 return emitOpError(
"transposition index out of range: ") << ta.value();
7243 if (seen[ta.value()])
7244 return emitOpError(
"duplicate position index: ") << ta.value();
7245 seen[ta.value()] =
true;
7246 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7247 return emitOpError(
"dimension size mismatch at: ") << ta.value();
7252std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7253 return llvm::to_vector<4>(getResultVectorType().
getShape());
7256void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7258 setResultRanges(getResult(), argRanges.front());
7264class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
7268 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7269 PatternRewriter &rewriter)
const override {
7271 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7272 ArrayRef<int64_t> permutation2) {
7273 SmallVector<int64_t, 4>
result;
7274 for (
auto index : permutation2)
7275 result.push_back(permutation1[index]);
7280 vector::TransposeOp parentTransposeOp =
7281 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7282 if (!parentTransposeOp)
7285 SmallVector<int64_t, 4> permutation = composePermutations(
7286 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7289 transposeOp, transposeOp.getResult().
getType(),
7290 parentTransposeOp.getVector(), permutation);
7296class FoldTransposeSplat final :
public OpRewritePattern<TransposeOp> {
7300 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7301 PatternRewriter &rewriter)
const override {
7302 Value splat = getScalarSplatSource(transposeOp.getVector());
7307 transposeOp, transposeOp.getResultVectorType(), splat);
7313class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
7317 LogicalResult matchAndRewrite(TransposeOp transpOp,
7318 PatternRewriter &rewriter)
const override {
7319 Value transposeSrc = transpOp.getVector();
7320 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
7321 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
7322 if (!createMaskOp && !constantMaskOp)
7327 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7330 auto maskOperands = createMaskOp.getOperands();
7331 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7335 transpOp, transpOp.getResultVectorType(), newOperands);
7340 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7344 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7350class FoldTransposeShapeCast final :
public OpRewritePattern<TransposeOp> {
7354 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7355 PatternRewriter &rewriter)
const override {
7357 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7360 if (!isOrderPreserving(transposeOp))
7363 VectorType resultType = transposeOp.getType();
7370 shapeCastOp.getSource());
7389class FoldTransposeFromElements final :
public OpRewritePattern<TransposeOp> {
7392 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7393 PatternRewriter &rewriter)
const override {
7394 auto fromElementsOp =
7395 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7396 if (!fromElementsOp)
7399 VectorType srcTy = fromElementsOp.getDest().getType();
7400 VectorType dstTy = transposeOp.getType();
7402 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7403 int64_t rank = srcTy.getRank();
7406 SmallVector<int64_t> inversePerm(rank, 0);
7407 for (int64_t i = 0; i < rank; ++i)
7408 inversePerm[permutation[i]] = i;
7410 ArrayRef<int64_t> srcShape = srcTy.getShape();
7411 ArrayRef<int64_t> dstShape = dstTy.getShape();
7412 SmallVector<int64_t> srcIdx(rank, 0);
7413 SmallVector<int64_t> dstIdx(rank, 0);
7417 auto elementsOld = fromElementsOp.getElements();
7418 SmallVector<Value> elementsNew;
7419 int64_t dstNumElements = dstTy.getNumElements();
7420 elementsNew.reserve(dstNumElements);
7424 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7428 for (int64_t j = 0; j < rank; ++j)
7429 srcIdx[j] = dstIdx[inversePerm[j]];
7431 int64_t srcLin =
linearize(srcIdx, srcStrides);
7433 elementsNew.push_back(elementsOld[srcLin]);
7467class FoldTransposeBroadcast :
public OpRewritePattern<vector::TransposeOp> {
7470 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7471 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7473 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7474 PatternRewriter &rewriter)
const override {
7480 "not preceded by a broadcast");
7483 auto inputType = dyn_cast<VectorType>(
broadcast.getSourceType());
7484 VectorType outputType = transpose.getResultVectorType();
7487 bool inputIsScalar = !inputType;
7488 if (inputIsScalar) {
7494 ArrayRef<int64_t> permutation = transpose.getPermutation();
7495 ArrayRef<int64_t> inputShape = inputType.getShape();
7496 int64_t inputRank = inputType.getRank();
7497 int64_t outputRank = transpose.getType().getRank();
7498 int64_t deltaRank = outputRank - inputRank;
7501 for (
int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7502 bool notOne = inputShape[inputIndex] != 1;
7503 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7504 bool groupEndFound = notOne || prevNotOne;
7505 if (groupEndFound) {
7506 int high = inputIndex + deltaRank;
7510 for (
int i = low; i < high; ++i) {
7511 if (permutation[i] < low || permutation[i] >= high) {
7513 transpose,
"permutation not local to group");
7527 vector::BroadcastableToResult::Success &&
7528 "not broadcastable directly to transpose output");
7539void vector::TransposeOp::getCanonicalizationPatterns(
7540 RewritePatternSet &results, MLIRContext *context) {
7541 results.
add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7542 FoldTransposeSplat, FoldTransposeFromElements,
7543 FoldTransposeBroadcast>(context);
7550void ConstantMaskOp::build(OpBuilder &builder, OperationState &
result,
7552 assert(kind == ConstantMaskKind::AllTrue ||
7553 kind == ConstantMaskKind::AllFalse);
7554 build(builder,
result, type,
7555 kind == ConstantMaskKind::AllTrue
7557 : SmallVector<int64_t>(type.getRank(), 0));
7560LogicalResult ConstantMaskOp::verify() {
7561 auto resultType = llvm::cast<VectorType>(getResult().
getType());
7563 if (resultType.getRank() == 0) {
7564 if (getMaskDimSizes().size() != 1)
7565 return emitError(
"array attr must have length 1 for 0-D vectors");
7566 auto dim = getMaskDimSizes()[0];
7567 if (dim != 0 && dim != 1)
7568 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
7573 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
7575 "must specify array attr of size equal vector result rank");
7578 auto resultShape = resultType.getShape();
7579 auto resultScalableDims = resultType.getScalableDims();
7580 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7581 for (
const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7582 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7584 "array attr of size out of bounds of vector result dimension size");
7585 if (resultScalableDims[index] && maskDimSize != 0 &&
7586 maskDimSize != resultShape[index])
7588 "only supports 'none set' or 'all set' scalable dimensions");
7592 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7593 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
7594 if (anyZeros && !allZeros)
7595 return emitOpError(
"expected all mask dim sizes to be zeros, "
7596 "as a result of conjunction with zero mask dim");
7600bool ConstantMaskOp::isAllOnesMask() {
7603 if (resultType.getRank() == 0) {
7604 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
7605 return getMaskDimSizes()[0] == 1;
7607 for (
const auto [resultSize, maskDimSize] :
7608 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7609 if (maskDimSize < resultSize)
7615OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7616 ArrayRef<int64_t> bounds = getMaskDimSizes();
7619 auto createBoolSplat = [&](
bool x) {
7625 if (vectorSizes.empty()) {
7626 assert(bounds.size() == 1 &&
"invalid sizes for zero rank mask");
7627 return createBoolSplat(bounds[0] == 1);
7630 if (bounds == vectorSizes)
7631 return createBoolSplat(
true);
7632 if (llvm::all_of(bounds, [](int64_t x) {
return x == 0; }))
7633 return createBoolSplat(
false);
7634 return OpFoldResult();
7641void CreateMaskOp::build(OpBuilder &builder, OperationState &
result,
7643 ArrayRef<OpFoldResult> mixedOperands) {
7644 SmallVector<Value> operands =
7646 build(builder,
result, type, operands);
7649LogicalResult CreateMaskOp::verify() {
7650 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
7652 if (vectorType.getRank() == 0) {
7653 if (getNumOperands() != 1)
7655 "must specify exactly one operand for 0-D create_mask");
7656 }
else if (getNumOperands() !=
7657 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
7659 "must specify an operand for each result vector dimension");
7689class CreateMaskFolder final :
public OpRewritePattern<CreateMaskOp> {
7693 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7694 PatternRewriter &rewriter)
const override {
7695 VectorType maskType = createMaskOp.getVectorType();
7696 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7697 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7700 constexpr std::array<int64_t, 1> rankZeroShape{1};
7701 constexpr std::array<bool, 1> rankZeroScalableDims{
false};
7702 if (maskType.getRank() == 0) {
7703 maskTypeDimSizes = rankZeroShape;
7704 maskTypeDimScalableFlags = rankZeroScalableDims;
7709 SmallVector<int64_t, 4> constantDims;
7710 for (
auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7715 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7717 constantDims.push_back(*intSize);
7721 if (vscaleMultiplier < maskTypeDimSizes[i])
7723 constantDims.push_back(*vscaleMultiplier);
7730 for (
auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7731 value = std::clamp<int64_t>(value, 0, maskDimSize);
7734 if (llvm::is_contained(constantDims, 0))
7735 constantDims.assign(constantDims.size(), 0);
7746void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7747 MLIRContext *context) {
7748 results.
add<CreateMaskFolder>(context);
7756 OpBuilder &builder, OperationState &
result, Value mask,
7757 Operation *maskableOp,
7758 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7759 assert(maskRegionBuilder &&
7760 "builder callback for 'maskRegion' must be present");
7762 result.addOperands(mask);
7763 OpBuilder::InsertionGuard guard(builder);
7764 Region *maskRegion =
result.addRegion();
7766 maskRegionBuilder(builder, maskableOp);
7771 Value mask, Operation *maskableOp,
7772 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7773 build(builder,
result, resultTypes, mask, Value(), maskableOp,
7779 Value mask, Value passthru, Operation *maskableOp,
7780 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7781 build(builder,
result, mask, maskableOp, maskRegionBuilder);
7783 result.addOperands(passthru);
7784 result.addTypes(resultTypes);
7787ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &
result) {
7789 result.regions.reserve(1);
7790 Region &maskRegion = *
result.addRegion();
7795 OpAsmParser::UnresolvedOperand mask;
7800 OpAsmParser::UnresolvedOperand passthru;
7802 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
7809 MaskOp::ensureTerminator(maskRegion, builder,
result.location);
7820 SmallVector<Type> resultTypes;
7823 result.types.append(resultTypes);
7829 if (parsePassthru.succeeded()) {
7830 if (resultTypes.empty())
7833 "expects a result if passthru operand is provided");
7842void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7843 p <<
" " << getMask();
7845 p <<
", " << getPassthru();
7849 Block *singleBlock = &getMaskRegion().getBlocks().front();
7856 p <<
" : " << getMask().getType();
7857 if (getNumResults() > 0)
7858 p <<
" -> " << getResultTypes();
7861void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
7864 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7865 MaskOp>::ensureTerminator(region, builder, loc);
7871 if (isa<vector::YieldOp>(block.
back()))
7879 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7880 MaskOp>::ensureTerminator(region, builder, loc);
7886 Operation *maskedOp = &block.
front();
7887 opBuilder.setInsertionPointToEnd(&block);
7888 vector::YieldOp::create(opBuilder, loc, maskedOp->
getResults());
7891LogicalResult MaskOp::verify() {
7893 Block &block = getMaskRegion().getBlocks().
front();
7895 return emitOpError(
"expects a terminator within the mask region");
7898 if (numMaskRegionOps > 2)
7899 return emitOpError(
"expects only one operation to mask");
7902 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
7904 return emitOpError(
"expects a terminator within the mask region");
7906 if (terminator->getNumOperands() != getNumResults())
7908 "expects number of results to match mask region yielded values");
7911 if (numMaskRegionOps == 1)
7914 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
7916 return emitOpError(
"expects a MaskableOpInterface within the mask region");
7920 return emitOpError(
"expects number of results to match maskable operation "
7921 "number of results");
7923 if (!llvm::equal(maskableOp->
getResults(), terminator.getOperands()))
7924 return emitOpError(
"expects all the results from the MaskableOpInterface "
7925 "to match all the values returned by the terminator");
7927 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
7929 "expects result type to match maskable operation result type");
7932 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7933 return emitOpError(
"multiple vector results not supported");
7936 Type expectedMaskType = maskableOp.getExpectedMaskType();
7937 if (getMask().
getType() != expectedMaskType)
7939 << expectedMaskType <<
" mask for the maskable operation";
7942 Value passthru = getPassthru();
7944 if (!maskableOp.supportsPassthru())
7946 "doesn't expect a passthru argument for this maskable operation");
7949 return emitOpError(
"expects result when passthru argument is provided");
7952 return emitOpError(
"expects passthru type to match result type");
7972static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7973 SmallVectorImpl<OpFoldResult> &results) {
7974 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7977 Block *block = maskOp.getMaskBlock();
7978 auto terminator = cast<vector::YieldOp>(block->
front());
7979 if (terminator.getNumOperands() == 0)
7983 llvm::append_range(results, terminator.getOperands());
7987LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7988 SmallVectorImpl<OpFoldResult> &results) {
7989 if (succeeded(foldEmptyMaskOp(*
this, adaptor, results)))
7999 Operation *maskableOp = getMaskableOp();
8005 llvm::append_range(results, maskableOp->
getResults());
8021class CanonializeEmptyMaskOp :
public OpRewritePattern<MaskOp> {
8024 LogicalResult matchAndRewrite(MaskOp maskOp,
8025 PatternRewriter &rewriter)
const override {
8026 if (!maskOp.isEmpty())
8029 if (!maskOp.hasPassthru())
8036 VectorType maskType = maskOp.getMask().getType();
8037 for (Type resultType : maskOp.getResultTypes()) {
8038 auto vecResultType = dyn_cast<VectorType>(resultType);
8039 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
8043 Block *block = maskOp.getMaskBlock();
8044 auto terminator = cast<vector::YieldOp>(block->
front());
8045 assert(terminator.getNumOperands() == 1 &&
8046 "expected one result when passthru is provided");
8049 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
8050 terminator.getOperand(0), maskOp.getPassthru());
8056void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
8057 MLIRContext *context) {
8058 results.
add<CanonializeEmptyMaskOp>(context);
8064Operation *MaskOp::getMaskableOp() {
8065 Block *block = getMaskBlock();
8069 return &block->
front();
8073bool MaskOp::hasPassthru() {
return getPassthru() != Value(); }
8079LogicalResult ScanOp::verify() {
8080 VectorType srcType = getSourceType();
8081 VectorType initialType = getInitialValueType();
8083 int64_t srcRank = srcType.getRank();
8084 int64_t reductionDim = getReductionDim();
8085 if (reductionDim >= srcRank)
8087 << reductionDim <<
" has to be less than " << srcRank;
8090 int64_t initialValueRank = initialType.getRank();
8091 if (initialValueRank != srcRank - 1)
8093 << initialValueRank <<
" has to be equal to " << srcRank - 1;
8096 ArrayRef<int64_t> srcShape = srcType.getShape();
8097 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
8098 SmallVector<int64_t> expectedShape;
8099 for (
int i = 0; i < srcRank; i++) {
8100 if (i != reductionDim)
8101 expectedShape.push_back(srcShape[i]);
8103 if (!llvm::equal(initialValueShapes, expectedShape)) {
8104 return emitOpError(
"incompatible input/initial value shapes");
8108 Type eltType = getDestType().getElementType();
8111 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
8118 RewritePatternSet &patterns, PatternBenefit benefit) {
8120 .
add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
8121 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
8122 StridedSliceConstantMaskFolder, TransposeFolder>(
8127 CombiningKind kind, Value v1, Value acc,
8128 arith::FastMathFlagsAttr fastmath,
8135 case CombiningKind::ADD:
8137 result =
b.createOrFold<arith::AddIOp>(loc, v1, acc);
8138 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8139 result =
b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
8141 llvm_unreachable(
"invalid value types for ADD reduction");
8143 case CombiningKind::AND:
8145 result =
b.createOrFold<arith::AndIOp>(loc, v1, acc);
8147 case CombiningKind::MAXNUMF:
8148 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8149 "expected float values");
8150 result =
b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
8152 case CombiningKind::MAXIMUMF:
8153 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8154 "expected float values");
8155 result =
b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
8157 case CombiningKind::MINNUMF:
8158 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8159 "expected float values");
8160 result =
b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
8162 case CombiningKind::MINIMUMF:
8163 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8164 "expected float values");
8165 result =
b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
8167 case CombiningKind::MAXSI:
8169 result =
b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
8171 case CombiningKind::MINSI:
8173 result =
b.createOrFold<arith::MinSIOp>(loc, v1, acc);
8175 case CombiningKind::MAXUI:
8177 result =
b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
8179 case CombiningKind::MINUI:
8181 result =
b.createOrFold<arith::MinUIOp>(loc, v1, acc);
8183 case CombiningKind::MUL:
8185 result =
b.createOrFold<arith::MulIOp>(loc, v1, acc);
8186 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8187 result =
b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
8189 llvm_unreachable(
"invalid value types for MUL reduction");
8191 case CombiningKind::OR:
8193 result =
b.createOrFold<arith::OrIOp>(loc, v1, acc);
8195 case CombiningKind::XOR:
8197 result =
b.createOrFold<arith::XOrIOp>(loc, v1, acc);
8201 assert(
result &&
"unknown CombiningKind");
8209void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
8211 auto resultType = cast<VectorType>(
getType());
8212 if (resultType.isScalable()) {
8216 APInt zero(bitwidth, 0);
8217 APInt high(bitwidth, resultType.getDimSize(0) - 1);
8218 ConstantIntRanges
result = {zero, high, zero, high};
8219 setResultRanges(getResult(),
result);
8249struct StepCompareFolder :
public OpRewritePattern<StepOp> {
8252 LogicalResult matchAndRewrite(StepOp stepOp,
8253 PatternRewriter &rewriter)
const override {
8254 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
8256 for (OpOperand &use : stepOp.getResult().getUses()) {
8257 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8262 const unsigned stepOperandNumber = use.getOperandNumber();
8263 if (stepOperandNumber != 0)
8267 unsigned constOperandNumber = 1;
8268 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8269 std::optional<int64_t> maybeConstValue =
8271 if (!maybeConstValue.has_value())
8274 int64_t constValue = maybeConstValue.value();
8275 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8277 auto maybeSplat = [&]() -> std::optional<bool> {
8279 if ((pred == arith::CmpIPredicate::ult ||
8280 pred == arith::CmpIPredicate::uge) &&
8281 stepSize <= constValue)
8282 return pred == arith::CmpIPredicate::ult;
8285 if ((pred == arith::CmpIPredicate::ule ||
8286 pred == arith::CmpIPredicate::ugt) &&
8287 stepSize - 1 <= constValue) {
8288 return pred == arith::CmpIPredicate::ule;
8292 if ((pred == arith::CmpIPredicate::eq ||
8293 pred == arith::CmpIPredicate::ne) &&
8294 stepSize <= constValue)
8295 return pred == arith::CmpIPredicate::ne;
8297 return std::nullopt;
8300 if (!maybeSplat.has_value())
8305 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8310 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8322void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8323 MLIRContext *context) {
8324 results.
add<StepCompareFolder>(context);
8334 Operation *maskableOp) {
8335 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
8347 Operation *maskableOp, Value mask,
8352 return MaskOp::create(builder, maskableOp->
getLoc(),
8355 return MaskOp::create(builder, maskableOp->
getLoc(),
8368 Value newValue, Value passthru) {
8372 return arith::SelectOp::create(builder, newValue.
getLoc(), newValue.
getType(),
8373 mask, newValue, passthru);
8384struct InterleaveDeinterleaveFolder :
public OpRewritePattern<InterleaveOp> {
8387 LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
8388 PatternRewriter &rewriter)
const override {
8389 auto lhsDefOp = interleaveOp.getLhs().getDefiningOp<DeinterleaveOp>();
8390 auto rhsDefOp = interleaveOp.getRhs().getDefiningOp<DeinterleaveOp>();
8391 if (!lhsDefOp || !rhsDefOp || lhsDefOp != rhsDefOp)
8393 for (
auto [idx, operand] : llvm::enumerate(interleaveOp.getOperands())) {
8394 if (cast<OpResult>(operand).getResultNumber() != idx)
8397 rewriter.
replaceOp(interleaveOp, lhsDefOp.getSource());
8403void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
8404 MLIRContext *context) {
8405 results.
add<InterleaveDeinterleaveFolder>(context);
8408std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
8409 return llvm::to_vector<4>(getResultVectorType().
getShape());
8416std::optional<SmallVector<int64_t, 4>> DeinterleaveOp::getShapeForUnroll() {
8417 return llvm::to_vector<4>(getResultVectorType().
getShape());
8424#define GET_ATTRDEF_CLASSES
8425#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8427#define GET_OP_CLASSES
8428#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static OpFoldResult foldShuffleIdentityMask(ShuffleOp op)
Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static OpFoldResult foldShufflePoisonInputs(MLIRContext *context, Attribute v1Attr, Attribute v2Attr)
Fold shuffle poison, poison -> poison.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr, Attribute v2Attr)
Fold a shuffle of constant 1-D inputs by evaluating the mask.
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op)
If a shuffle operand is poison, replace all mask indices that reference it with kPoisonIndex.
static LogicalResult foldSize1TransferPermutationMap(TransferOp op)
When the vector type is vector<1xT>, the permutation map is irrelevant: the single vector lane always...
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)