43#include "llvm/ADT/ArrayRef.h"
44#include "llvm/ADT/Repeated.h"
45#include "llvm/ADT/STLExtras.h"
46#include "llvm/ADT/SmallVector.h"
47#include "llvm/ADT/SmallVectorExtras.h"
48#include "llvm/ADT/StringSet.h"
49#include "llvm/ADT/TypeSwitch.h"
50#include "llvm/Support/Casting.h"
56#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
58#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
79 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
81 for (
bool b : denseElts.getValues<
bool>())
84 else if (!
b && val <= 0)
98 auto shape = m.getType().getShape();
100 bool allFalse =
true;
101 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks,
shape)) {
102 if (maskIdx < dimSize)
115 auto maskOperands = m.getOperands();
116 for (
Value operand : maskOperands) {
117 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
119 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
132 vector::YieldOp::create(builder, loc);
138 switch (combiningKind) {
139 case CombiningKind::ADD:
140 case CombiningKind::MUL:
142 case CombiningKind::MINUI:
143 case CombiningKind::MINSI:
144 case CombiningKind::MAXUI:
145 case CombiningKind::MAXSI:
146 case CombiningKind::AND:
147 case CombiningKind::OR:
148 case CombiningKind::XOR:
150 case CombiningKind::MINNUMF:
151 case CombiningKind::MAXNUMF:
152 case CombiningKind::MINIMUMF:
153 case CombiningKind::MAXIMUMF:
154 return llvm::isa<FloatType>(elementType);
184 VectorType vectorType) {
185 unsigned elementVectorRank = 0;
186 VectorType elementVectorType =
187 llvm::dyn_cast<VectorType>(shapedType.getElementType());
188 if (elementVectorType)
189 elementVectorRank += elementVectorType.getRank();
190 return vectorType.getRank() - elementVectorRank;
194 VectorType vectorType) {
197 if (shapedType.getRank() == 0 &&
203 shapedType.getRank(),
205 shapedType.getContext());
212 vector::TransferReadOp read) {
213 auto readMask = read.getMask();
214 auto writeMask = write.getMask();
220 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
221 if (!couldBeSameSplat)
238 vector::TransferReadOp read) {
239 return !defWrite.hasOutOfBoundsDim() &&
240 defWrite.getIndices() == read.getIndices() &&
241 defWrite.getVectorType() == read.getVectorType() &&
242 defWrite.getPermutationMap() == read.getPermutationMap() &&
243 ((!defWrite.getMask() && !read.getMask()) ||
248 vector::TransferWriteOp priorWrite) {
249 return priorWrite.getIndices() == write.getIndices() &&
250 priorWrite.getMask() == write.getMask() &&
251 priorWrite.getVectorType() == write.getVectorType() &&
252 priorWrite.getPermutationMap() == write.getPermutationMap();
256 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
257 bool testDynamicValueUsingBounds) {
259 if (transferA.getVectorType() != transferB.getVectorType())
261 unsigned rankOffset = transferA.getLeadingShapedRank();
262 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
263 Value indexA = transferA.getIndices()[i];
264 Value indexB = transferB.getIndices()[i];
268 if (i < rankOffset) {
271 if (cstIndexA.has_value() && cstIndexB.has_value()) {
272 if (*cstIndexA != *cstIndexB)
276 if (testDynamicValueUsingBounds) {
279 FailureOr<uint64_t> delta =
281 if (succeeded(delta) && *delta != 0)
284 FailureOr<bool> testEqual =
286 if (succeeded(testEqual) && !testEqual.value())
292 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
293 if (cstIndexA.has_value() && cstIndexB.has_value()) {
294 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
295 if (distance >= vectorDim)
299 if (testDynamicValueUsingBounds) {
302 FailureOr<int64_t> delta =
304 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
307 FailureOr<int64_t> computeDelta =
309 if (succeeded(computeDelta)) {
310 if (std::abs(computeDelta.value()) >= vectorDim)
320 VectorTransferOpInterface transferB,
321 bool testDynamicValueUsingBounds) {
322 if (transferA.getBase() != transferB.getBase())
325 testDynamicValueUsingBounds);
335 for (
auto [posInDim, dimSize, offsetInDim] :
336 llvm::reverse(llvm::zip_equal(position,
shape, offsets))) {
338 if (posInDim < dimSize + offsetInDim)
342 posInDim = offsetInDim;
352 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
354 assert(constOp &&
"Unexpected non-constant index");
355 return constOp.value();
365 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
366 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
367 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
377 llvm::transform(foldResults, std::back_inserter(values),
379 if (
auto attr = dyn_cast<Attribute>(foldResult))
381 builder, loc, cast<IntegerAttr>(attr).getInt())
384 return cast<Value>(foldResult);
397 if (
lhs.getDefiningOp<vector::VectorScaleOp>())
399 if (
rhs.getDefiningOp<vector::VectorScaleOp>())
409 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
410 if (
auto intType = dyn_cast<IntegerType>(expectedType)) {
411 if (intAttr.getType() != expectedType)
412 return IntegerAttr::get(expectedType, intAttr.getInt());
418 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
419 auto intType = dyn_cast<IntegerType>(expectedType);
423 APFloat floatVal = floatAttr.getValue();
424 APInt intVal = floatVal.bitcastToAPInt();
425 return IntegerAttr::get(expectedType, intVal);
464struct VectorInlinerInterface :
public DialectInlinerInterface {
465 using DialectInlinerInterface::DialectInlinerInterface;
474void VectorDialect::initialize() {
476#define GET_ATTRDEF_LIST
477#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
482#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
485 addInterfaces<VectorInlinerInterface>();
487 declarePromisedInterfaces<memref::IndexedAccessOpInterface, LoadOp, StoreOp,
488 MaskedLoadOp, MaskedStoreOp, ExpandLoadOp,
490 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
491 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
493 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
495 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
496 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
497 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
508 return arith::ConstantOp::materialize(builder, value, type, loc);
524void vector::MultiDimReductionOp::build(
OpBuilder &builder,
527 CombiningKind kind) {
529 for (
const auto &en : llvm::enumerate(reductionMask))
531 reductionDims.push_back(en.index());
532 build(builder,
result, kind, source,
acc, reductionDims);
535OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
537 if (getReductionDims().empty())
542std::optional<SmallVector<int64_t, 4>>
543MultiDimReductionOp::getShapeForUnroll() {
544 return llvm::to_vector<4>(getSourceVectorType().
getShape());
547LogicalResult MultiDimReductionOp::verify() {
550 Type inferredReturnType;
551 auto sourceScalableDims = getSourceVectorType().getScalableDims();
552 for (
auto [dimIdx, dimSize] :
553 llvm::enumerate(getSourceVectorType().
getShape()))
554 if (!llvm::any_of(getReductionDims(),
555 [dimIdx = dimIdx](
int64_t reductionDimIdx) {
556 return reductionDimIdx ==
static_cast<int64_t>(dimIdx);
558 targetShape.push_back(dimSize);
559 scalableDims.push_back(sourceScalableDims[dimIdx]);
562 if (targetShape.empty())
563 inferredReturnType = getSourceVectorType().getElementType();
565 inferredReturnType = VectorType::get(
566 targetShape, getSourceVectorType().
getElementType(), scalableDims);
567 if (
getType() != inferredReturnType)
569 <<
" is incompatible with source type "
570 << getSourceVectorType();
576Type MultiDimReductionOp::getExpectedMaskType() {
577 auto vecType = getSourceVectorType();
578 return VectorType::get(vecType.getShape(),
579 IntegerType::get(vecType.getContext(), 1),
580 vecType.getScalableDims());
589struct ElideUnitDimsInMultiDimReduction
593 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
594 PatternRewriter &rewriter)
const override {
595 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
596 for (
const auto &dim :
enumerate(shape)) {
597 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
602 OpBuilder::InsertionGuard guard(rewriter);
605 if (reductionOp.isMasked()) {
607 rootOp = reductionOp.getMaskingOp();
608 mask = reductionOp.getMaskingOp().getMask();
610 rootOp = reductionOp;
613 Location loc = reductionOp.getLoc();
614 Value acc = reductionOp.getAcc();
616 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
618 VectorType newMaskType =
619 VectorType::get(dstVecType.getShape(), rewriter.
getI1Type(),
620 dstVecType.getScalableDims());
621 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
623 cast = vector::ShapeCastOp::create(
624 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
629 mask = vector::ExtractOp::create(rewriter, loc, mask);
630 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
635 cast,
nullptr, mask);
642void MultiDimReductionOp::getCanonicalizationPatterns(
644 results.
add<ElideUnitDimsInMultiDimReduction>(context);
653 arith::FastMathFlags fastMathFlags) {
659 arith::FastMathFlags fastMathFlags) {
661 llvm::cast<VectorType>(
vector.getType()).getElementType(), kind,
vector,
665LogicalResult ReductionOp::verify() {
667 int64_t rank = getSourceVectorType().getRank();
669 return emitOpError(
"unsupported reduction rank: ") << rank;
672 Type eltType = getDest().getType();
675 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
684Type ReductionOp::getExpectedMaskType() {
685 auto vecType = getSourceVectorType();
686 return VectorType::get(vecType.getShape(),
687 IntegerType::get(vecType.getContext(), 1),
688 vecType.getScalableDims());
695 case arith::AtomicRMWKind::addf:
696 case arith::AtomicRMWKind::addi:
697 return vector::ReductionOp::create(builder,
vector.getLoc(),
698 CombiningKind::ADD,
vector);
699 case arith::AtomicRMWKind::mulf:
700 case arith::AtomicRMWKind::muli:
701 return vector::ReductionOp::create(builder,
vector.getLoc(),
702 CombiningKind::MUL,
vector);
703 case arith::AtomicRMWKind::minimumf:
704 return vector::ReductionOp::create(builder,
vector.getLoc(),
705 CombiningKind::MINIMUMF,
vector);
706 case arith::AtomicRMWKind::mins:
707 return vector::ReductionOp::create(builder,
vector.getLoc(),
708 CombiningKind::MINSI,
vector);
709 case arith::AtomicRMWKind::minu:
710 return vector::ReductionOp::create(builder,
vector.getLoc(),
711 CombiningKind::MINUI,
vector);
712 case arith::AtomicRMWKind::maximumf:
713 return vector::ReductionOp::create(builder,
vector.getLoc(),
714 CombiningKind::MAXIMUMF,
vector);
715 case arith::AtomicRMWKind::maxs:
716 return vector::ReductionOp::create(builder,
vector.getLoc(),
717 CombiningKind::MAXSI,
vector);
718 case arith::AtomicRMWKind::maxu:
719 return vector::ReductionOp::create(builder,
vector.getLoc(),
720 CombiningKind::MAXUI,
vector);
721 case arith::AtomicRMWKind::andi:
722 return vector::ReductionOp::create(builder,
vector.getLoc(),
723 CombiningKind::AND,
vector);
724 case arith::AtomicRMWKind::ori:
725 return vector::ReductionOp::create(builder,
vector.getLoc(),
726 CombiningKind::OR,
vector);
727 case arith::AtomicRMWKind::minnumf:
728 return vector::ReductionOp::create(builder,
vector.getLoc(),
729 CombiningKind::MINNUMF,
vector);
730 case arith::AtomicRMWKind::maxnumf:
731 return vector::ReductionOp::create(builder,
vector.getLoc(),
732 CombiningKind::MAXNUMF,
vector);
733 case arith::AtomicRMWKind::xori:
734 return vector::ReductionOp::create(builder,
vector.getLoc(),
735 CombiningKind::XOR,
vector);
743std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
744 return llvm::to_vector<4>(getSourceVectorType().
getShape());
751 LogicalResult matchAndRewrite(ReductionOp reductionOp,
756 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
759 if (maskableOp.isMasked()) {
761 rootOp = maskableOp.getMaskingOp();
762 mask = maskableOp.getMaskingOp().getMask();
764 rootOp = reductionOp;
767 auto vectorType = reductionOp.getSourceVectorType();
768 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
771 Location loc = reductionOp.getLoc();
773 mask = ExtractOp::create(rewriter, loc, mask);
774 Value
result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
776 if (Value acc = reductionOp.getAcc())
779 reductionOp.getFastmathAttr(), mask);
789 results.
add<ElideSingleElementReduction>(context);
803 getIndexingMapsAttrName(
result.name),
807 getIteratorTypesAttrName(
result.name),
810 return IteratorTypeAttr::get(builder.getContext(), t);
819 ContractionOp::getDefaultKind());
825 ArrayAttr iteratorTypes, CombiningKind kind,
826 arith::FastMathFlags fastMathFlags) {
829 result.addAttribute(getIndexingMapsAttrName(
result.name), indexingMaps);
830 result.addAttribute(getIteratorTypesAttrName(
result.name), iteratorTypes);
832 CombiningKindAttr::get(builder.
getContext(), kind));
833 if (fastMathFlags != arith::FastMathFlags::none)
835 getFastmathAttrName(
result.name),
836 arith::FastMathFlagsAttr::get(builder.
getContext(), fastMathFlags));
847 DictionaryAttr dictAttr;
861 result.attributes.append(dictAttr.getValue().begin(),
862 dictAttr.getValue().end());
868 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
869 result.attributes.get(getIteratorTypesAttrName(
result.name)));
870 if (!iteratorTypes) {
872 <<
"expected " << getIteratorTypesAttrName(
result.name)
873 <<
" array attribute";
878 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
879 auto maybeIteratorType = symbolizeIteratorType(s);
880 if (!maybeIteratorType.has_value())
881 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
883 iteratorTypeAttrs.push_back(
884 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
886 result.attributes.set(getIteratorTypesAttrName(
result.name),
889 if (!
result.attributes.get(getKindAttrName(
result.name))) {
891 getKindAttrName(
result.name),
892 CombiningKindAttr::get(
result.getContext(),
893 ContractionOp::getDefaultKind()));
895 if (masksInfo.empty())
897 if (masksInfo.size() != 2)
899 "expected zero or exactly 2 vector mask operands");
900 auto lhsType = llvm::cast<VectorType>(types[0]);
901 auto rhsType = llvm::cast<VectorType>(types[1]);
903 std::array<VectorType, 2> maskTypes = {
913 auto attrNames = getTraitAttrNames();
915 traitAttrsSet.insert_range(attrNames);
917 for (
auto attr : (*this)->getAttrs()) {
918 if (attr.getName() == getIteratorTypesAttrName()) {
920 llvm::cast<ArrayAttr>(attr.getValue())
921 .getAsValueRange<IteratorTypeAttr, IteratorType>();
927 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) ->
Attribute {
928 return StringAttr::get(
getContext(), stringifyIteratorType(t));
931 attrs.emplace_back(getIteratorTypesAttrName(),
932 ArrayAttr::get(
getContext(), iteratorTypeNames));
933 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
935 if (attr.getName() == getFastmathAttrName() &&
936 llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
937 arith::FastMathFlags::none)
939 attrs.push_back(attr);
943 auto dictAttr = DictionaryAttr::get(
getContext(), attrs);
944 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
945 p << getRhs() <<
", " << getAcc();
948 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
953 const std::vector<std::pair<int64_t, int64_t>> &map) {
954 for (
auto &dimPair : map) {
955 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
956 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
957 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
964 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
966 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
967 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
970 for (
auto &dimPair : contractingDimMap) {
971 lhsContractingDimSet.insert(dimPair.first);
972 rhsContractingDimSet.insert(dimPair.second);
975 llvm::make_second_range(batchDimMap));
979 for (
int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
980 if (lhsContractingDimSet.count(i) > 0)
982 expectedResultDims.push_back(lhsType.getDimSize(i));
986 for (
int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
987 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
989 expectedResultDims.push_back(rhsType.getDimSize(i));
993 if (expectedResultDims.empty()) {
995 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
996 return op.emitOpError(
"invalid accumulator/result vector shape");
999 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
1000 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
1001 if (!resVectorType || !accVectorType)
1002 return op.emitOpError(
"invalid accumulator/result vector shape");
1008 AffineMap lhsMap = op.getIndexingMapsArray()[0];
1009 AffineMap rhsMap = op.getIndexingMapsArray()[1];
1011 return op.emitOpError(
1012 "expected all dimensions to be either a LHS or a RHS dimension");
1015 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1016 VectorType v = pair.first;
1017 auto map = pair.second;
1018 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1019 unsigned pos = map.getDimPosition(idx);
1024 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
1025 return op.emitOpError(
"expected all dimensions to get an extent as "
1026 "either a LHS or a RHS dimension");
1028 AffineMap resMap = op.getIndexingMapsArray()[2];
1033 assert(llvm::all_of(expectedMap.
getResults(),
1034 llvm::IsaPred<AffineConstantExpr>) &&
1035 "expected constant extent along all dimensions.");
1037 auto expectedShape =
1039 return cast<AffineConstantExpr>(e).getValue();
1042 VectorType::get(expectedShape, resVectorType.getElementType(),
1043 resVectorType.getScalableDims());
1044 if (resVectorType != expected || accVectorType != expected)
1045 return op.emitOpError(
1046 "invalid accumulator/result vector shape, expected: ")
1052LogicalResult ContractionOp::verify() {
1053 VectorType lhsType = getLhsType();
1054 VectorType rhsType = getRhsType();
1055 Type accType = getAccType();
1056 Type resType = getResultType();
1058 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1059 if (!lhsType.getElementType().isSignlessInteger())
1060 return emitOpError(
"only supports signless integer types");
1064 if (getIndexingMapsArray().size() != 3)
1065 return emitOpError(
"expected an indexing map for each vector operand");
1070 unsigned numIterators = getIteratorTypes().getValue().size();
1071 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1072 auto index = it.index();
1073 auto map = it.value();
1074 if (map.getNumSymbols() != 0)
1076 <<
index <<
" to have no symbols";
1077 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(
index).
getType());
1078 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1081 if (map.getNumDims() != numIterators)
1083 <<
index <<
" to have " << numIterators <<
" number of inputs";
1084 if (map.getNumResults() != rank)
1086 <<
index <<
" to have " << rank <<
" number of outputs";
1087 if (!map.isProjectedPermutation())
1089 <<
index <<
" to be a projected permutation of its inputs";
1092 auto contractingDimMap = getContractingDimMap();
1093 auto batchDimMap = getBatchDimMap();
1096 if (contractingDimMap.empty())
1097 return emitOpError(
"expected at least one contracting dimension pair");
1100 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1101 return emitOpError(
"invalid contracting dimension map");
1105 return emitOpError(
"invalid batch dimension map");
1109 contractingDimMap, batchDimMap)))
1112 if (!getKindAttr()) {
1113 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
1114 "'vector.kind<add>')");
1118 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1119 auto elementType = vectorType ? vectorType.getElementType() : resType;
1121 return emitOpError(
"unsupported contraction type");
1124 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1131Type ContractionOp::getExpectedMaskType() {
1132 auto indexingMaps = this->getIndexingMapsArray();
1135 VectorType lhsType = this->getLhsType();
1136 VectorType rhsType = this->getRhsType();
1138 unsigned numVecDims = lhsIdxMap.
getNumDims();
1144 for (
auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1147 lhsType.getScalableDims()[dimIdx];
1149 for (
auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1152 rhsType.getScalableDims()[dimIdx];
1155 assert(ShapedType::isStaticShape(maskShape) &&
1156 "Mask shape couldn't be computed");
1158 return VectorType::get(maskShape,
1159 IntegerType::get(lhsType.getContext(), 1),
1160 maskShapeScalableDims);
1165 getIteratorTypesAttrName(), getKindAttrName(),
1166 getFastmathAttrName()};
1176static std::vector<std::pair<int64_t, int64_t>>
1178 IteratorType targetIteratorType,
MLIRContext *context) {
1179 std::vector<std::pair<int64_t, int64_t>> dimMap;
1180 for (
const auto &it : llvm::enumerate(iteratorTypes)) {
1181 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1182 if (iteratorType != targetIteratorType)
1188 if (lhsDim >= 0 && rhsDim >= 0)
1189 dimMap.emplace_back(lhsDim, rhsDim);
1194void ContractionOp::getIterationBounds(
1196 auto lhsShape = getLhsType().getShape();
1197 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1199 for (
const auto &it : llvm::enumerate(getIteratorTypes())) {
1202 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1203 if (iteratorType == IteratorType::reduction) {
1206 assert(lhsDimIndex >= 0);
1207 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1212 assert(resDimIndex >= 0);
1213 assert(resVectorType !=
nullptr);
1214 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1218void ContractionOp::getIterationIndexMap(
1220 unsigned numMaps = getIndexingMapsArray().size();
1221 iterationIndexMap.resize(numMaps);
1222 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1223 auto index = it.index();
1224 auto map = it.value();
1225 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1226 auto dim = cast<AffineDimExpr>(map.getResult(i));
1227 iterationIndexMap[
index][dim.getPosition()] = i;
1232std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1234 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1238std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1240 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1244std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1246 getIterationBounds(
shape);
1268template <
typename AddOpType>
1274 auto canonicalize = [&](
Value maybeContraction,
1275 Value otherOperand) -> vector::ContractionOp {
1276 vector::ContractionOp contractionOp =
1277 dyn_cast_or_null<vector::ContractionOp>(
1280 return vector::ContractionOp();
1281 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1282 contractionOp.getAcc().getDefiningOp())) {
1283 if (maybeZero.getValue() ==
1284 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1286 bvm.
map(contractionOp.getAcc(), otherOperand);
1287 auto newContraction =
1288 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1289 rewriter.
replaceOp(addOp, newContraction.getResult());
1290 return newContraction;
1293 return vector::ContractionOp();
1296 Value a = addOp->getOperand(0),
b = addOp->getOperand(1);
1297 vector::ContractionOp
contract = canonicalize(a,
b);
1322 setResultRanges(getResult(), argRanges.front());
1327 auto vectorTy = cast<VectorType>(source.
getType());
1352 build(builder,
result, source, dynamicPos,
1357ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1358 ExtractOp::Adaptor adaptor,
1360 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1361 if (
static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1362 vectorType.getRank()) {
1363 inferredReturnTypes.push_back(vectorType.getElementType());
1365 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1366 vectorType.getRank());
1367 inferredReturnTypes.push_back(VectorType::get(
1368 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1369 vectorType.getScalableDims().drop_front(n)));
1374LogicalResult vector::ExtractOp::verify() {
1375 if (
auto resTy = dyn_cast<VectorType>(getResult().
getType()))
1376 if (resTy.getRank() == 0)
1378 "expected a scalar instead of a 0-d vector as the result type");
1381 auto dynamicMarkersCount =
1382 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1383 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1385 "mismatch between dynamic and static positions (kDynamic marker but no "
1386 "corresponding dynamic position) -- this can only happen due to an "
1387 "incorrect fold/rewrite");
1388 auto position = getMixedPosition();
1389 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1391 "expected position attribute of rank no greater than vector rank");
1392 for (
auto [idx, pos] : llvm::enumerate(position)) {
1393 if (
auto attr = dyn_cast<Attribute>(pos)) {
1394 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1396 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1397 return emitOpError(
"expected position attribute #")
1399 <<
" to be a non-negative integer smaller than the "
1400 "corresponding vector dimension or poison (-1)";
1407template <
typename IntType>
1409 return llvm::map_to_vector<4>(
1410 arrayAttr.getAsRange<IntegerAttr>(),
1411 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1417 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1421 if (extractOp.hasDynamicPosition())
1425 ExtractOp currentOp = extractOp;
1427 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1428 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1431 if (currentOp.hasDynamicPosition())
1434 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1436 extractOp.setOperand(0, currentOp.getSource());
1439 std::reverse(globalPosition.begin(), globalPosition.end());
1440 extractOp.setStaticPosition(globalPosition);
1452class ExtractFromInsertTransposeChainState {
1454 ExtractFromInsertTransposeChainState(ExtractOp e);
1463 template <
typename ContainerA,
typename ContainerB>
1464 bool isContainedWithin(
const ContainerA &a,
const ContainerB &
b) {
1465 return a.size() <=
b.size() &&
1466 std::equal(a.begin(), a.begin() + a.size(),
b.begin());
1473 template <
typename ContainerA,
typename ContainerB>
1474 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &
b) {
1475 for (
auto [elemA, elemB] : llvm::zip(a,
b)) {
1476 if (elemA < 0 || elemB < 0)
1487 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1491 void updateStateForNextIteration(Value v) {
1498 LogicalResult handleTransposeOp();
1501 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1516 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1521 Value tryToFoldExtractOpInPlace(Value source);
1523 ExtractOp extractOp;
1525 int64_t extractedRank;
1527 InsertOp nextInsertOp;
1528 TransposeOp nextTransposeOp;
1538 SmallVector<int64_t> sentinels;
1539 SmallVector<int64_t> extractPosition;
1543ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1545 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1546 extractedRank(extractOp.getNumIndices()) {
1547 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1548 sentinels.reserve(vectorRank - extractedRank);
1549 for (
int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1550 sentinels.push_back(-(i + 1));
1552 extractOp.getStaticPosition().end());
1558LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1560 if (extractOp.hasDynamicPosition())
1563 if (!nextTransposeOp)
1566 nextTransposeOp.getPermutation(), extractOp.getContext()));
1573ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1576 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1579 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1580 if (insertedPos != llvm::ArrayRef(
extractPosition).take_front(extractedRank))
1583 res = nextInsertOp.getValueToStore();
1592ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1594 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1597 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1607 res = nextInsertOp.getValueToStore();
1615Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1618 if (extractOp.hasDynamicPosition())
1622 bool nothingToFold = (source == extractOp.getSource());
1623 if (nothingToFold || !canFold())
1627 OpBuilder
b(extractOp.getContext());
1628 extractOp.setStaticPosition(
1630 extractOp.getSourceMutable().assign(source);
1631 return extractOp.getResult();
1635Value ExtractFromInsertTransposeChainState::fold() {
1637 if (extractOp.hasDynamicPosition())
1640 Value valueToExtractFrom = extractOp.getSource();
1641 updateStateForNextIteration(valueToExtractFrom);
1642 while (nextInsertOp || nextTransposeOp) {
1645 if (succeeded(handleTransposeOp())) {
1646 valueToExtractFrom = nextTransposeOp.getVector();
1647 updateStateForNextIteration(valueToExtractFrom);
1653 if (succeeded(handleInsertOpWithMatchingPos(
result)))
1658 if (succeeded(handleInsertOpWithPrefixPos(
result)))
1659 return tryToFoldExtractOpInPlace(
result);
1663 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1669 valueToExtractFrom = nextInsertOp.getDest();
1670 updateStateForNextIteration(valueToExtractFrom);
1673 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1678 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1679 auto vecType = dyn_cast<VectorType>(type);
1680 return vecType && vecType.getRank() == 0;
1690 if (isa<BroadcastOp>(op))
1693 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1701 VectorType srcType = shapeCast.getSourceVectorType();
1703 uint64_t srcRank = srcType.getRank();
1705 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1731 Operation *defOp = extractOp.getSource().getDefiningOp();
1738 if (extractOp.getType() == input.
getType())
1744 auto inputType = llvm::dyn_cast<VectorType>(input.
getType());
1745 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1746 unsigned inputRank = inputType ? inputType.getRank() : 0;
1747 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1748 unsigned extractRank = extractType ? extractType.getRank() : 0;
1751 if (extractRank > inputRank)
1755 assert(inputType &&
"input must be a vector type because of previous checks");
1764 extractType.getShape() != inputShape.take_back(extractRank))
1769 unsigned deltaOverall = inputRank - extractRank;
1770 unsigned deltaBroadcast = broadcastRank - inputRank;
1774 for (
auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1775 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1778 extractOp->setOperands(
1779 llvm::to_vector(llvm::concat<Value>(
ValueRange(input), dynPos)));
1780 extractOp.setStaticPosition(staticPos);
1781 return extractOp.getResult();
1797 if (extractOp.hasDynamicPosition())
1800 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1805 if (shuffleOp.getResultVectorType().getRank() != 1)
1808 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1809 auto shuffleMask = shuffleOp.getMask();
1810 int64_t extractIdx = extractOp.getStaticPosition()[0];
1811 int64_t shuffleIdx = shuffleMask[extractIdx];
1814 if (shuffleIdx < inputVecSize) {
1815 extractOp.setOperand(0, shuffleOp.getV1());
1816 extractOp.setStaticPosition({shuffleIdx});
1818 extractOp.setOperand(0, shuffleOp.getV2());
1819 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1822 return extractOp.getResult();
1828 if (extractOp.hasDynamicPosition())
1831 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1836 auto getDimReverse = [](VectorType type,
int64_t n) {
1837 return type.getShape().take_back(n + 1).front();
1840 llvm::isa<VectorType>(extractOp.getType())
1841 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1843 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1845 if (destinationRank > 0) {
1846 auto destinationType =
1847 llvm::cast<VectorType>(extractOp.getResult().getType());
1848 for (
int64_t i = 0; i < destinationRank; i++) {
1852 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1853 getDimReverse(destinationType, i))
1860 std::reverse(extractedPos.begin(), extractedPos.end());
1863 for (
int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1864 strides.push_back(stride);
1866 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1874 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1876 for (
int64_t i = 0; i < numDimension; i++) {
1877 newStrides.push_back(stride);
1879 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1881 std::reverse(newStrides.begin(), newStrides.end());
1885 extractOp.setStaticPosition(newPosition);
1886 extractOp.setOperand(0, shapeCastOp.getSource());
1887 return extractOp.getResult();
1893 if (extractOp.hasDynamicPosition())
1896 auto extractStridedSliceOp =
1897 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1898 if (!extractStridedSliceOp)
1907 if (extractStridedSliceOp.hasNonUnitStrides())
1913 while (!sliceOffsets.empty()) {
1914 size_t lastOffset = sliceOffsets.size() - 1;
1915 if (sliceOffsets.back() != 0 ||
1916 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1917 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1919 sliceOffsets.pop_back();
1921 unsigned destinationRank = 0;
1922 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1923 destinationRank = vecType.getRank();
1926 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1927 sliceOffsets.size())
1931 assert(extractedPos.size() >= sliceOffsets.size());
1932 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1933 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1934 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1938 extractOp.setStaticPosition(extractedPos);
1939 return extractOp.getResult();
1945 if (extractOp.hasDynamicPosition())
1949 llvm::isa<VectorType>(extractOp.getType())
1950 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1952 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1962 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1963 insertOp.getSourceVectorType().getRank();
1964 if (destinationRank > insertOp.getSourceVectorType().getRank())
1969 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1970 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1973 bool disjoint =
false;
1975 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1976 int64_t start = insertOffsets[dim];
1978 (dim < insertRankDiff)
1980 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1982 int64_t offset = extractOffsets[dim];
1984 if (start <= offset && offset < end) {
1985 if (dim >= insertRankDiff)
1986 offsetDiffs.push_back(offset - start);
1997 insertOp.getSourceVectorType().getRank() - destinationRank;
1998 for (
int64_t i = 0; i < destinationRank; i++) {
1999 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
2000 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
2004 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2007 extractOp.setStaticPosition(offsetDiffs);
2008 return extractOp.getResult();
2012 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2025 if (extractOp.hasDynamicPosition())
2029 auto fromElementsOp = extractOp.getSource().
getDefiningOp<FromElementsOp>();
2030 if (!fromElementsOp)
2034 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2035 if (vecType.isScalable())
2039 int64_t rank = vecType.getRank();
2041 if (extractOp.getType() != vecType.getElementType())
2044 "unexpected number of indices");
2049 for (
int i = rank - 1; i >= 0; --i) {
2050 flatIndex +=
indices[i] * stride;
2051 stride *= vecType.getDimSize(i);
2053 return fromElementsOp.getElements()[flatIndex];
2058template <
typename OpType,
typename AdaptorType>
2061 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2062 OperandRange dynamicPosition = op.getDynamicPosition();
2065 if constexpr (std::is_same_v<OpType, ExtractOp>)
2066 vectorShape = op.getSourceVectorType().getShape();
2071 if (!dynamicPosition.size())
2078 bool opChange =
false;
2079 for (
unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2080 if (ShapedType::isStatic(staticPosition[i]))
2084 if (
auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2085 int64_t value = attr.getInt();
2089 staticPosition[i] = attr.getInt();
2094 operands.push_back(position);
2098 op.setStaticPosition(staticPosition);
2099 op.getOperation()->setOperands(operands);
2101 return op.getResult();
2111 if (!is_contained(staticPos, poisonVal))
2114 return ub::PoisonAttr::get(context);
2128 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2133 if (denseAttr.isSplat()) {
2135 if (
auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2140 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2141 if (vecTy.isScalable())
2144 if (extractOp.hasDynamicPosition()) {
2159 copy(extractOp.getStaticPosition(), completePositions.begin());
2162 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2165 if (
auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2167 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2170 newAttr = *denseValuesBegin;
2176OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2180 if (getNumIndices() == 0 && getSource().
getType() == getResult().
getType())
2187 SmallVector<Value> operands = {getSource()};
2191 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2197 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2212 return inplaceFolded;
2218class ExtractOpFromBroadcast final :
public OpRewritePattern<ExtractOp> {
2222 LogicalResult matchAndRewrite(ExtractOp extractOp,
2223 PatternRewriter &rewriter)
const override {
2226 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2232 BroadcastableToResult::Success)
2241class ExtractOpFromCreateMask final :
public OpRewritePattern<ExtractOp> {
2245 LogicalResult matchAndRewrite(ExtractOp extractOp,
2246 PatternRewriter &rewriter)
const override {
2248 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2252 VectorType extractedMaskType =
2253 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2255 if (!extractedMaskType)
2258 auto maskOperands = createMaskOp.getOperands();
2259 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2260 VectorType maskType = createMaskOp.getVectorType();
2262 bool containsUnknownDims =
false;
2265 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2267 int64_t pos = extractOpPos[dimIdx];
2268 Value operand = maskOperands[dimIdx];
2269 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2272 containsUnknownDims =
true;
2276 int64_t createMaskBound =
2277 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2279 if (pos != ShapedType::kDynamic) {
2282 allFalse |= pos >= createMaskBound;
2283 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2287 containsUnknownDims =
true;
2294 }
else if (!containsUnknownDims) {
2296 extractOp, extractedMaskType,
2297 maskOperands.drop_front(extractOpPos.size()));
2306class ExtractOpFromConstantMask final :
public OpRewritePattern<ExtractOp> {
2310 LogicalResult matchAndRewrite(ExtractOp extractOp,
2311 PatternRewriter &rewriter)
const override {
2312 auto constantMaskOp =
2313 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2314 if (!constantMaskOp)
2317 Type resultType = extractOp.getResult().getType();
2318 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2320 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2321 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2323 VectorType maskType = constantMaskOp.getVectorType();
2326 for (
size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2327 int64_t pos = extractOpPos[dimIdx];
2328 if (pos == ShapedType::kDynamic) {
2331 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2340 if (pos >= maskDimSizes[dimIdx]) {
2341 if (extractedMaskType) {
2353 if (extractedMaskType) {
2357 extractOp, extractedMaskType,
2358 maskDimSizes.drop_front(extractOpPos.size()));
2371LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2372 PatternRewriter &rewriter) {
2373 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2377 VectorType sourceType = castOp.getSourceVectorType();
2378 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2382 if (sourceType.getNumElements() != targetType.getNumElements())
2386 castOp.getSource());
2396LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2397 PatternRewriter &rewriter) {
2399 if (extractOp.hasDynamicPosition())
2403 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2408 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2409 if (!fromElementsOp)
2411 VectorType inputType = fromElementsOp.getType();
2414 if (resultType.isScalable() || inputType.isScalable())
2419 SmallVector<int64_t> firstElementPos =
2420 llvm::to_vector(extractOp.getStaticPosition());
2421 firstElementPos.append(resultType.getRank(), 0);
2424 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2425 flatIndex += firstElementPos[i] * stride;
2426 stride *= inputType.getDimSize(i);
2431 extractOp, resultType,
2432 fromElementsOp.getElements().slice(flatIndex,
2433 resultType.getNumElements()));
2445struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2447 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2448 PatternRewriter &rewriter)
const override {
2449 VectorType sourceType = extractOp.getSourceVectorType();
2450 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2454 if (sourceType.getNumElements() != outType.getNumElements())
2456 extractOp,
"extract to vector with fewer elements");
2460 if (llvm::any_of(extractOp.getMixedPosition(),
2461 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2463 "leaving for extract poison folder");
2466 extractOp.getSource());
2487struct FoldExtractFromInsertUnitDim final
2488 : OpRewritePattern<vector::ExtractOp> {
2491 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2492 PatternRewriter &rewriter)
const override {
2493 if (extractOp.hasDynamicPosition())
2496 auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
2497 if (!insertOp || insertOp.hasDynamicPosition())
2500 ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
2501 ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
2504 if (extractPos.size() >= insertPos.size() ||
2505 extractPos != insertPos.take_front(extractPos.size()))
2511 auto srcVecType = extractOp.getSourceVectorType();
2512 for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
2513 if (srcVecType.getDimSize(i) != 1)
2516 Value
inserted = insertOp.getValueToStore();
2517 Type extractedType = extractOp.getResult().getType();
2518 if (isa<VectorType>(
inserted.getType())) {
2525 extractOp, extractOp.getResult().
getType(),
2526 insertOp.getValueToStore());
2534void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2535 MLIRContext *context) {
2536 results.
add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2537 ExtractOpFromConstantMask, ExtractToShapeCast,
2538 FoldExtractFromInsertUnitDim>(context);
2539 results.
add(foldExtractFromShapeCastToShapeCast);
2540 results.
add(foldExtractFromFromElements);
2545 for (
auto attr : arrayAttr)
2546 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2553std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2564 if (operands.empty())
2567 return llvm::all_of(operands, [&](
Value operand) {
2569 return currentDef == defOp;
2587 auto fromElementsOp =
2588 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2589 if (!fromElementsOp)
2592 llvm::append_range(results, fromElementsOp.getElements());
2609 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2613 if (isa<VectorType>(bcastOp.getSource().getType()))
2616 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2618 Value scalar = bcastOp.getSource();
2619 results.assign(resultVecType.getNumElements(), scalar);
2623LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2624 SmallVectorImpl<OpFoldResult> &results) {
2629 if (
auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2630 setOperand(shapeCast.getSource());
2638ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2639 ToElementsOp::Adaptor adaptor,
2640 SmallVectorImpl<Type> &inferredReturnTypes) {
2641 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2642 Type elType = vecType.getElementType();
2643 inferredReturnTypes.append(vecType.getNumElements(), elType);
2665 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2670 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2674 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2679 int64_t dstRank = dstShape.size();
2680 int64_t srcRank = srcShape.size();
2683 auto srcElems = vector::ToElementsOp::create(
2684 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2686 int64_t dstCount = llvm::product_of(dstShape);
2689 replacements.reserve(dstCount);
2714 for (
int64_t lin = 0; lin < dstCount; ++lin) {
2717 for (
int64_t k = 0; k < srcRank; ++k)
2718 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2721 replacements.push_back(srcElems.getResult(srcLin));
2724 rewriter.
replaceOp(toElementsOp, replacements);
2729void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2730 MLIRContext *context) {
2731 results.
add<ToElementsOfBroadcast>(context);
2751 OperandRange fromElemsOperands = fromElementsOp.getElements();
2752 if (fromElemsOperands.empty())
2755 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2763 Value toElementsInput = toElementsOp.getSource();
2764 if (fromElementsOp.getType() == toElementsInput.
getType() &&
2765 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2766 return toElementsInput;
2786 if (llvm::any_of(elements, [](
Attribute attr) {
2792 auto destVecType = fromElementsOp.getDest().getType();
2793 auto destEltType = destVecType.getElementType();
2794 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2799 auto convertedElements = llvm::map_to_vector(elements, [&](
Attribute attr) {
2806OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2823 if (!llvm::all_equal(fromElementsOp.getElements()))
2826 fromElementsOp, fromElementsOp.getType(),
2827 fromElementsOp.getElements().front());
2855 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2859 if (fromElements.getType().getNumElements() == 1)
2870 for (
auto [insertIndex, element] :
2871 llvm::enumerate(fromElements.getElements())) {
2874 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2877 "element not from vector.extract");
2882 if (insertIndex == 0) {
2883 source = extractOp.getSource();
2884 }
else if (extractOp.getSource() != source) {
2886 "element from different vector");
2890 int64_t rank = position.size();
2891 assert(rank == source.getType().getRank() &&
2892 "scalar extract must have full rank position");
2903 if (insertIndex == 0) {
2904 const int64_t numElms = fromElements.getType().getNumElements();
2907 while (
index > 0 && position[
index - 1] == 0 &&
2908 numSuffixElms < numElms) {
2909 numSuffixElms *= source.getType().getDimSize(
index - 1);
2912 if (numSuffixElms != numElms) {
2914 fromElements,
"elements do not form a suffix of source");
2916 expectedPosition = llvm::to_vector(position);
2917 combinedPosition = position.drop_back(rank -
index);
2921 else if (expectedPosition != position) {
2923 fromElements,
"elements not in ascending order (static order)");
2925 increment(expectedPosition, source.getType().getShape());
2928 auto extracted = rewriter.
createOrFold<vector::ExtractOp>(
2929 fromElements.getLoc(), source, combinedPosition);
2932 fromElements, fromElements.getType(), extracted);
2940 for (
int dim : llvm::reverse(llvm::seq<int>(0,
indices.size()))) {
2959void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2961 setResultRanges(getResult(), argRanges.front());
2964std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2965 return llvm::to_vector<4>(getResultVectorType().
getShape());
2970static llvm::SetVector<int64_t>
2973 int64_t rankDiff = dstShape.size() - srcShape.size();
2976 for (
auto [s1, s2] :
2977 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2979 assert(s1 == 1 &&
"expected \"dim-1\" broadcasting");
2987llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2989 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2992 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
3008Value BroadcastOp::createOrFoldBroadcastOp(
3009 OpBuilder &
b, Value value, ArrayRef<int64_t> dstShape,
3010 const llvm::SetVector<int64_t> &broadcastedDims) {
3011 assert(!dstShape.empty() &&
"unexpected empty dst shape");
3014 SmallVector<int64_t> checkShape;
3015 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
3016 if (broadcastedDims.contains(i))
3018 checkShape.push_back(dstShape[i]);
3020 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
3021 "ill-formed broadcastedDims contains values not confined to "
3024 Location loc = value.
getLoc();
3026 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
3027 VectorType dstVectorType = VectorType::get(dstShape, elementType);
3030 if (!srcVectorType) {
3031 assert(checkShape.empty() &&
3032 "ill-formed createOrFoldBroadcastOp arguments");
3033 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
3036 assert(srcVectorType.getShape().equals(checkShape) &&
3037 "ill-formed createOrFoldBroadcastOp arguments");
3047 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
3048 broadcastShape.reserve(dstShape.size());
3064 int64_t nextSrcShapeDim = broadcastedDims.size();
3065 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
3066 if (broadcastedDims.contains(i)) {
3071 broadcastShape.push_back(dstShape[i]);
3072 permutation[i] = broadcastShape.size() - 1;
3078 permutation[i] = nextSrcShapeDim++;
3082 llvm::append_range(broadcastShape, srcVectorType.getShape());
3087 "unexpected \"dim-1\" broadcast");
3089 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3091 vector::BroadcastableToResult::Success &&
3092 "must be broadcastable");
3093 Value res =
b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3096 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3097 if (permutation[i] != i)
3098 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3104 Type srcType, VectorType dstVectorType,
3105 std::pair<VectorDim, VectorDim> *mismatchingDims) {
3107 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3111 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3115 int64_t srcRank = srcVectorType.getRank();
3116 int64_t dstRank = dstVectorType.getRank();
3117 if (srcRank > dstRank)
3121 int64_t lead = dstRank - srcRank;
3122 for (
int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3125 bool foundMismatchingDims =
false;
3128 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3129 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3130 if (srcDim != 1 && srcDim != dstDim)
3131 foundMismatchingDims =
true;
3134 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3135 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3136 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3139 (srcDimScalableFlag != dstDimScalableFlag &&
3140 (srcDim != 1 || srcDimScalableFlag)))
3141 foundMismatchingDims =
true;
3143 if (foundMismatchingDims) {
3144 if (mismatchingDims !=
nullptr) {
3145 mismatchingDims->first.dim = srcDim;
3146 mismatchingDims->first.isScalable = srcDimScalableFlag;
3148 mismatchingDims->second.dim = dstDim;
3149 mismatchingDims->second.isScalable = dstDimScalableFlag;
3158LogicalResult BroadcastOp::verify() {
3159 std::pair<VectorDim, VectorDim> mismatchingDims;
3161 getSourceType(), getResultVectorType(), &mismatchingDims);
3165 return emitOpError(
"source rank higher than destination rank");
3168 << (mismatchingDims.first.isScalable ?
"[" :
"")
3169 << mismatchingDims.first.dim
3170 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
3171 << (mismatchingDims.second.isScalable ?
"[" :
"")
3172 << mismatchingDims.second.dim
3173 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
3176 return emitOpError(
"source type is not a vector");
3177 llvm_unreachable(
"unexpected vector.broadcast op error");
3184 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3188 VectorType srcType = srcShapeCast.getSourceVectorType();
3189 VectorType destType = broadcastOp.getResultVectorType();
3197 srcShapeCast.getResultVectorType().getShape();
3200 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3201 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3202 shapecastShape.take_back(numTrailingDims)))
3205 assert(all_of(srcShape.drop_back(numTrailingDims),
3206 [](
int64_t E) { return E == 1; }) &&
3207 all_of(shapecastShape.drop_back(numTrailingDims),
3208 [](
int64_t E) { return E == 1; }) &&
3209 "ill-formed shape_cast");
3211 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3215OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3216 if (getSourceType() == getResultVectorType())
3221 if (!adaptor.getSource())
3223 auto vectorType = getResultVectorType();
3224 if (
auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3225 if (vectorType.getElementType() != attr.getType())
3229 if (
auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3230 if (vectorType.getElementType() != attr.getType())
3234 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3244struct BroadcastFolder :
public OpRewritePattern<BroadcastOp> {
3247 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3248 PatternRewriter &rewriter)
const override {
3249 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3253 broadcastOp.getResultVectorType(),
3254 srcBroadcast.getSource());
3267struct BroadcastToShapeCast final
3268 :
public OpRewritePattern<vector::BroadcastOp> {
3270 LogicalResult matchAndRewrite(vector::BroadcastOp
broadcast,
3271 PatternRewriter &rewriter)
const override {
3273 auto sourceType = dyn_cast<VectorType>(
broadcast.getSourceType());
3276 broadcast,
"source is a scalar, shape_cast doesn't support scalar");
3280 if (sourceType.getNumElements() != outType.getNumElements()) {
3282 broadcast,
"broadcast to a greater number of elements");
3292void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3293 MLIRContext *context) {
3294 results.
add<BroadcastFolder, BroadcastToShapeCast>(context);
3301LogicalResult ShuffleOp::verify() {
3302 VectorType resultType = getResultVectorType();
3303 VectorType v1Type = getV1VectorType();
3304 VectorType v2Type = getV2VectorType();
3306 int64_t resRank = resultType.getRank();
3307 int64_t v1Rank = v1Type.getRank();
3308 int64_t v2Rank = v2Type.getRank();
3309 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3310 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3311 if (!wellFormed0DCase && !wellFormedNDCase)
3315 for (int64_t r = 1; r < v1Rank; ++r) {
3316 int64_t resDim = resultType.getDimSize(r);
3317 int64_t v1Dim = v1Type.getDimSize(r);
3318 int64_t v2Dim = v2Type.getDimSize(r);
3319 if (resDim != v1Dim || v1Dim != v2Dim)
3323 ArrayRef<int64_t> mask = getMask();
3324 int64_t maskLength = mask.size();
3325 if (maskLength <= 0)
3327 if (maskLength != resultType.getDimSize(0))
3330 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3331 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3332 for (
auto [idx, maskPos] : llvm::enumerate(mask)) {
3334 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
3340ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3341 ShuffleOp::Adaptor adaptor,
3342 SmallVectorImpl<Type> &inferredReturnTypes) {
3343 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3347 auto v1Rank = v1Type.getRank();
3350 SmallVector<int64_t, 4> shape;
3351 shape.reserve(v1Rank);
3352 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3355 llvm::append_range(shape, v1Type.getShape().drop_front());
3356 inferredReturnTypes.push_back(
3357 VectorType::get(shape, v1Type.getElementType()));
3361template <
typename T>
3364 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3365 return value == expected++;
3372 auto v1Type = op.getV1VectorType();
3373 auto v2Type = op.getV2VectorType();
3374 auto mask = op.getMask();
3387 if (!isV1Poison && !isV2Poison)
3390 int64_t v1Size = op.getV1VectorType().getDimSize(0);
3391 bool changed =
false;
3393 for (
int64_t &idx : newMask) {
3394 if (idx == ShuffleOp::kPoisonIndex)
3396 if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
3397 idx = ShuffleOp::kPoisonIndex;
3405 op.setMask(newMask);
3406 return op.getResult();
3415 return ub::PoisonAttr::get(context);
3422 auto v1Type = op.getV1VectorType();
3423 if (v1Type.getRank() != 1)
3435 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3438 v2Elements = to_vector(v2DenseAttr.getValues<
Attribute>());
3439 poisonElement = v2Elements[0];
3442 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3445 v1Elements = to_vector(v1DenseAttr.getValues<
Attribute>());
3446 poisonElement = v1Elements[0];
3451 int64_t v1Size = v1Type.getDimSize(0);
3452 for (
int64_t maskIdx : mask) {
3455 if (maskIdx == ShuffleOp::kPoisonIndex) {
3456 indexedElm = poisonElement;
3458 if (maskIdx < v1Size)
3459 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3461 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3464 results.push_back(indexedElm);
3470OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3471 auto v1Type = getV1VectorType();
3473 assert(!v1Type.isScalable() && !getV2VectorType().isScalable() &&
3474 "Vector shuffle does not support scalable vectors");
3478 if (v1Type.getRank() == 0)
3486 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3487 if (!v1Attr || !v2Attr)
3502struct Canonicalize0DShuffleOp :
public OpRewritePattern<ShuffleOp> {
3505 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3506 PatternRewriter &rewriter)
const override {
3507 VectorType v1VectorType = shuffleOp.getV1VectorType();
3508 ArrayRef<int64_t> mask = shuffleOp.getMask();
3509 if (v1VectorType.getRank() > 0)
3511 if (mask.size() != 1)
3513 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3531static Value getScalarSplatSource(Value value) {
3537 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3544 if (isa<VectorType>(
broadcast.getSourceType()))
3552class ShuffleSplat final :
public OpRewritePattern<ShuffleOp> {
3556 LogicalResult matchAndRewrite(ShuffleOp op,
3557 PatternRewriter &rewriter)
const override {
3558 Value splat = getScalarSplatSource(op.getV1());
3559 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3569class ShuffleInterleave :
public OpRewritePattern<ShuffleOp> {
3573 LogicalResult matchAndRewrite(ShuffleOp op,
3574 PatternRewriter &rewriter)
const override {
3575 VectorType resultType = op.getResultVectorType();
3576 if (resultType.isScalable())
3578 op,
"ShuffleOp can't represent a scalable interleave");
3580 if (resultType.getRank() != 1)
3582 op,
"ShuffleOp can't represent an n-D interleave");
3584 VectorType sourceType = op.getV1VectorType();
3585 if (sourceType != op.getV2VectorType() ||
3586 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3588 op,
"ShuffleOp types don't match an interleave");
3591 ArrayRef<int64_t> shuffleMask = op.getMask();
3592 int64_t resultVectorSize = resultType.getNumElements();
3593 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3594 int64_t maskValueA = shuffleMask[i * 2];
3595 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3596 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3598 "ShuffleOp mask not interleaving");
3614class FoldUnusedShuffleOperand final :
public OpRewritePattern<ShuffleOp> {
3618 LogicalResult matchAndRewrite(ShuffleOp op,
3619 PatternRewriter &rewriter)
const override {
3621 if (llvm::all_of(op.getMask(), [](int64_t mask) {
3622 return mask == ShuffleOp::kPoisonIndex;
3629 auto replaceOperandWithPoison = [&](OpOperand &operand) {
3632 Value poison = ub::PoisonOp::create(rewriter, op.getLoc(),
3641 int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
3642 ? op.getV1VectorType().getDimSize(0)
3644 bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3645 return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
3647 if (!isV1Used && succeeded(replaceOperandWithPoison(op.getV1Mutable())))
3651 bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3652 return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
3654 if (!isV2Used && succeeded(replaceOperandWithPoison(op.getV2Mutable())))
3662void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3663 MLIRContext *context) {
3664 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
3665 FoldUnusedShuffleOperand>(context);
3672void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3674 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3677void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3678 Value source, Value dest) {
3679 auto vectorTy = cast<VectorType>(dest.
getType());
3680 build(builder,
result, source, dest,
3681 SmallVector<int64_t>(vectorTy.getRank(), 0));
3684void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3685 Value source, Value dest, int64_t position) {
3686 build(builder,
result, source, dest, ArrayRef<int64_t>{position});
3689void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3690 Value source, Value dest, OpFoldResult position) {
3691 build(builder,
result, source, dest, ArrayRef<OpFoldResult>{position});
3694void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3695 Value source, Value dest,
3696 ArrayRef<int64_t> position) {
3697 SmallVector<OpFoldResult> posVals;
3698 posVals.reserve(position.size());
3699 llvm::transform(position, std::back_inserter(posVals),
3701 build(builder,
result, source, dest, posVals);
3704void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3705 Value source, Value dest,
3706 ArrayRef<OpFoldResult> position) {
3707 SmallVector<int64_t> staticPos;
3708 SmallVector<Value> dynamicPos;
3710 build(builder,
result, source, dest, dynamicPos,
3714LogicalResult InsertOp::verify() {
3715 if (
auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3716 if (srcTy.getRank() == 0)
3718 "expected a scalar instead of a 0-d vector as the source operand");
3720 SmallVector<OpFoldResult> position = getMixedPosition();
3721 auto destVectorType = getDestVectorType();
3722 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
3724 "expected position attribute of rank no greater than dest vector rank");
3725 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3726 if (srcVectorType &&
3727 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3728 static_cast<unsigned>(destVectorType.getRank())))
3729 return emitOpError(
"expected position attribute rank + source rank to "
3730 "match dest vector rank");
3731 if (!srcVectorType &&
3732 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
3734 "expected position attribute rank to match the dest vector rank");
3735 for (
auto [idx, pos] : llvm::enumerate(position)) {
3736 if (
auto attr = dyn_cast<Attribute>(pos)) {
3737 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3739 destVectorType.getDimSize(idx))) {
3740 return emitOpError(
"expected position attribute #")
3742 <<
" to be a non-negative integer smaller than the "
3744 "dest vector dimension";
3757 assert(positions.size() <= completePositions.size() &&
3758 "positions size must be less than or equal to destTy rank");
3759 copy(positions, completePositions.begin());
3767class InsertToBroadcast final :
public OpRewritePattern<InsertOp> {
3771 LogicalResult matchAndRewrite(InsertOp insertOp,
3772 PatternRewriter &rewriter)
const override {
3774 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3775 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3776 srcVecType.getNumElements())
3779 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3785class InsertSplatToSplat final :
public OpRewritePattern<InsertOp> {
3789 LogicalResult matchAndRewrite(InsertOp op,
3790 PatternRewriter &rewriter)
const override {
3792 Value splat = getScalarSplatSource(op.getValueToStore());
3793 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3821class InsertChainFullyInitialized final :
public OpRewritePattern<InsertOp> {
3824 LogicalResult matchAndRewrite(InsertOp op,
3825 PatternRewriter &rewriter)
const override {
3827 VectorType destTy = op.getDestVectorType();
3828 if (destTy.isScalable())
3831 for (Operation *user : op.getResult().getUsers())
3832 if (
auto insertOp = dyn_cast<InsertOp>(user))
3833 if (insertOp.getDest() == op.getResult())
3836 InsertOp currentOp = op;
3837 SmallVector<InsertOp> chainInsertOps;
3840 if (currentOp.hasDynamicPosition())
3843 chainInsertOps.push_back(currentOp);
3844 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3847 if (currentOp && !currentOp->hasOneUse())
3851 int64_t vectorSize = destTy.getNumElements();
3852 int64_t initializedCount = 0;
3853 SmallVector<bool> initializedDestIdxs(vectorSize,
false);
3854 SmallVector<int64_t> pendingInsertPos;
3855 SmallVector<int64_t> pendingInsertSize;
3856 SmallVector<Value> pendingInsertValues;
3858 for (
auto insertOp : chainInsertOps) {
3860 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3864 int64_t insertBeginPosition =
3869 int64_t insertSize = 1;
3870 if (
auto srcVectorType =
3871 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3872 insertSize = srcVectorType.getNumElements();
3874 assert(insertBeginPosition + insertSize <= vectorSize &&
3875 "insert would overflow the vector");
3877 for (
auto index : llvm::seq<int64_t>(insertBeginPosition,
3878 insertBeginPosition + insertSize)) {
3879 if (initializedDestIdxs[index])
3881 initializedDestIdxs[index] =
true;
3887 pendingInsertPos.push_back(insertBeginPosition);
3888 pendingInsertSize.push_back(insertSize);
3889 pendingInsertValues.push_back(insertOp.getValueToStore());
3891 if (initializedCount == vectorSize)
3896 if (initializedCount != vectorSize)
3899 SmallVector<Value> elements(vectorSize);
3900 for (
auto [insertBeginPosition, insertSize, valueToStore] :
3901 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3902 pendingInsertValues))) {
3903 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3905 if (!srcVectorType) {
3906 elements[insertBeginPosition] = valueToStore;
3910 Repeated<Type> elementToInsertTypes(insertSize,
3911 srcVectorType.getElementType());
3913 auto elementsToInsert = vector::ToElementsOp::create(
3914 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3915 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3916 elements[insertBeginPosition + linearIdx] =
3917 elementsToInsert.getResult(linearIdx);
3931 int64_t maxVectorSizeFoldThreshold) {
3932 if (insertOp.hasDynamicPosition())
3935 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3943 VectorType destTy = insertOp.getDestVectorType();
3944 if (destTy.isScalable())
3948 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3949 !insertOp->hasOneUse())
3954 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3961 Type destEltType = destTy.getElementType();
3965 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3966 for (
auto value : denseSource.getValues<
Attribute>())
3972 auto allValues = llvm::to_vector(denseDst.getValues<
Attribute>());
3973 copy(insertedValues, allValues.begin() + insertBeginPosition);
3982 auto destInsert = insertOp.getDest().
getDefiningOp<InsertOp>();
3986 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3989 insertOp.
setOperand(1, destInsert.getDest());
3990 return insertOp.getResult();
3993void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3994 MLIRContext *context) {
3995 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3996 InsertChainFullyInitialized>(context);
3999OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
4002 constexpr int64_t vectorSizeFoldThreshold = 256;
4006 if (getNumIndices() == 0 && getValueToStoreType() ==
getType())
4007 return getValueToStore();
4011 SmallVector<Value> operands = {getValueToStore(), getDest()};
4017 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
4020 *
this, adaptor.getValueToStore(), adaptor.getDest(),
4021 vectorSizeFoldThreshold)) {
4025 return inplaceFolded;
4032void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4033 Value source, Value dest,
4034 ArrayRef<int64_t> offsets,
4035 ArrayRef<int64_t> strides) {
4036 result.addOperands({source, dest});
4040 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(
result.name),
4042 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(
result.name),
4047template <
typename OpType>
4051 StringRef attrName) {
4052 if (arrayAttr.size() >
shape.size())
4053 return op.emitOpError(
"expected ")
4054 << attrName <<
" attribute of rank no greater than vector rank";
4061template <
typename OpType>
4065 bool halfOpen =
true) {
4066 for (
auto attr : arrayAttr) {
4067 auto val = llvm::cast<IntegerAttr>(attr).getInt();
4071 if (val < min || val >= upper)
4072 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
4073 <<
min <<
", " << upper <<
")";
4081template <
typename OpType>
4086 for (
auto [
index, attrDimPair] :
4087 llvm::enumerate(llvm::zip_first(arrayAttr,
shape))) {
4088 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
4092 if (val < min || val >=
max)
4093 return op.emitOpError(
"expected ")
4094 << attrName <<
" dimension " <<
index <<
" to be confined to ["
4095 <<
min <<
", " <<
max <<
")";
4105template <
typename OpType>
4110 assert(arrayAttr1.size() <=
shape.size());
4111 assert(arrayAttr2.size() <=
shape.size());
4112 for (
auto [
index, it] :
4113 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2,
shape))) {
4114 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
4115 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
4119 if (val1 + val2 < 0 || val1 + val2 >=
max)
4120 return op.emitOpError(
"expected sum(")
4121 << attrName1 <<
", " << attrName2 <<
") dimension " <<
index
4122 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
4130 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
4132 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
4135LogicalResult InsertStridedSliceOp::verify() {
4136 auto sourceVectorType = getSourceVectorType();
4137 auto destVectorType = getDestVectorType();
4138 auto offsets = getOffsetsAttr();
4139 auto strides = getStridesAttr();
4140 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
4142 "expected offsets of same size as destination vector rank");
4143 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
4144 return emitOpError(
"expected strides of same size as source vector rank");
4145 if (sourceVectorType.getRank() > destVectorType.getRank())
4147 "expected source rank to be no greater than destination rank");
4149 auto sourceShape = sourceVectorType.getShape();
4150 auto destShape = destVectorType.getShape();
4151 SmallVector<int64_t, 4> sourceShapeAsDestShape(
4152 destShape.size() - sourceShape.size(), 0);
4153 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
4154 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
4155 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
4164 offName,
"source vector shape",
4168 unsigned rankDiff = destShape.size() - sourceShape.size();
4169 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
4170 if (sourceVectorType.getScalableDims()[idx] !=
4171 destVectorType.getScalableDims()[idx + rankDiff]) {
4172 return emitOpError(
"mismatching scalable flags (at source vector idx=")
4175 if (sourceVectorType.getScalableDims()[idx]) {
4176 auto sourceSize = sourceShape[idx];
4177 auto destSize = destShape[idx + rankDiff];
4178 if (sourceSize != destSize) {
4181 << (
" to match the corresponding base size from the input "
4183 << sourceSize << (
" vs ") << destSize << (
")");
4193class FoldInsertStridedSliceSplat final
4194 :
public OpRewritePattern<InsertStridedSliceOp> {
4198 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4199 PatternRewriter &rewriter)
const override {
4201 auto dst = insertStridedSliceOp.getDest();
4202 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4203 if (!splat || getScalarSplatSource(dst) != splat)
4206 rewriter.
replaceOp(insertStridedSliceOp, dst);
4213class FoldInsertStridedSliceOfExtract final
4214 :
public OpRewritePattern<InsertStridedSliceOp> {
4218 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4219 PatternRewriter &rewriter)
const override {
4220 auto extractStridedSliceOp =
4221 insertStridedSliceOp.getValueToStore()
4222 .getDefiningOp<vector::ExtractStridedSliceOp>();
4224 if (!extractStridedSliceOp)
4227 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4231 if (extractStridedSliceOp.getStrides() !=
4232 insertStridedSliceOp.getStrides() ||
4233 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4236 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4243class InsertStridedSliceConstantFolder final
4244 :
public OpRewritePattern<InsertStridedSliceOp> {
4250 static constexpr int64_t vectorSizeFoldThreshold = 256;
4252 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4253 PatternRewriter &rewriter)
const override {
4257 Attribute vectorDestCst;
4261 VectorType destTy = destVector.getType();
4262 if (destTy.isScalable())
4266 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4267 !destVector.hasOneUse())
4271 Attribute sourceCst;
4281 if (op.hasNonUnitStrides())
4284 VectorType sliceVecTy = sourceValue.getType();
4285 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4286 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4287 SmallVector<int64_t, 4> offsets =
getI64SubArray(op.getOffsets());
4288 SmallVector<int64_t, 4> destStrides =
computeStrides(destTy.getShape());
4296 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4297 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4298 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4299 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4300 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4301 MutableArrayRef<int64_t> currSlicePosition(
4302 currDestPosition.begin() + rankDifference, currDestPosition.end());
4303 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4306 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
4307 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
4308 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4309 "Invalid slice element");
4310 newValues[linearizedPosition] = *sliceValuesIt;
4323void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4324 RewritePatternSet &results, MLIRContext *context) {
4325 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4326 InsertStridedSliceConstantFolder>(context);
4329OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4330 if (getSourceVectorType() == getDestVectorType())
4331 return getValueToStore();
4340void OuterProductOp::build(OpBuilder &builder, OperationState &
result,
4341 Value
lhs, Value
rhs, Value acc) {
4346void OuterProductOp::print(OpAsmPrinter &p) {
4347 p <<
" " << getLhs() <<
", " << getRhs();
4349 p <<
", " << getAcc();
4352 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
4355ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &
result) {
4356 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4363 if (operandsInfo.size() < 2)
4365 "expected at least 2 operands");
4366 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4367 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4370 "expected vector type for operand #1");
4374 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4375 vRHS.getScalableDims()[0]};
4376 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4377 vLHS.getElementType(), scalableDimsRes);
4380 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4381 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4385 if (!
result.attributes.get(OuterProductOp::getKindAttrName(
result.name))) {
4386 result.attributes.append(
4387 OuterProductOp::getKindAttrName(
result.name),
4388 CombiningKindAttr::get(
result.getContext(),
4389 OuterProductOp::getDefaultKind()));
4395 (operandsInfo.size() > 2 &&
4400LogicalResult OuterProductOp::verify() {
4401 Type tRHS = getOperandTypeRHS();
4402 VectorType vLHS = getOperandVectorTypeLHS(),
4403 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4404 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4406 if (vLHS.getRank() != 1)
4407 return emitOpError(
"expected 1-d vector for operand #1");
4411 if (vRHS.getRank() != 1)
4412 return emitOpError(
"expected 1-d vector for operand #2");
4413 if (vRES.getRank() != 2)
4415 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4416 return emitOpError(
"expected #1 operand dim to match result dim #1");
4417 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4418 return emitOpError(
"expected #2 operand dim to match result dim #2");
4419 if (vLHS.isScalable() && !vRHS.isScalable()) {
4423 "expected either both or only #2 operand dim to be scalable");
4427 if (vRES.getRank() != 1)
4429 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4430 return emitOpError(
"expected #1 operand dim to match result dim #1");
4433 if (vACC && vACC != vRES)
4434 return emitOpError(
"expected operand #3 of same type as result type");
4436 if (!getKindAttr()) {
4437 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
4438 "'vector.kind<add>')");
4443 return emitOpError(
"unsupported outerproduct type");
4452Type OuterProductOp::getExpectedMaskType() {
4453 auto vecType = this->getResultVectorType();
4454 return VectorType::get(vecType.getShape(),
4455 IntegerType::get(vecType.getContext(), 1),
4456 vecType.getScalableDims());
4470 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4472 shape.reserve(vectorType.getRank());
4474 for (
unsigned e = offsets.size(); idx < e; ++idx)
4475 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4476 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4477 shape.push_back(vectorType.getShape()[idx]);
4479 return VectorType::get(
shape, vectorType.getElementType(),
4480 vectorType.getScalableDims());
4483void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4484 Value source, ArrayRef<int64_t> offsets,
4485 ArrayRef<int64_t> sizes,
4486 ArrayRef<int64_t> strides) {
4487 result.addOperands(source);
4493 offsetsAttr, sizesAttr, stridesAttr));
4494 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(
result.name),
4496 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(
result.name),
4498 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(
result.name),
4502LogicalResult ExtractStridedSliceOp::verify() {
4503 auto type = getSourceVectorType();
4504 auto offsets = getOffsetsAttr();
4505 auto sizes = getSizesAttr();
4506 auto strides = getStridesAttr();
4507 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4509 "expected offsets, sizes and strides attributes of same size");
4511 auto shape = type.getShape();
4512 auto offName = getOffsetsAttrName();
4513 auto sizesName = getSizesAttrName();
4514 auto stridesName = getStridesAttrName();
4530 shape, offName, sizesName,
4535 offsets, sizes, strides);
4536 if (getResult().
getType() != resultType)
4537 return emitOpError(
"expected result type to be ") << resultType;
4539 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
4540 if (type.getScalableDims()[idx]) {
4541 auto inputDim = type.getShape()[idx];
4542 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4543 if (inputDim != inputSize)
4546 << (
" to match the corresponding base size from the input "
4548 << inputSize << (
" vs ") << inputDim << (
")");
4561 auto getElement = [](
ArrayAttr array,
int idx) {
4562 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4564 ArrayAttr extractOffsets = op.getOffsets();
4567 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4569 if (op.getSourceVectorType().getRank() !=
4570 insertOp.getSourceVectorType().getRank())
4572 ArrayAttr insertOffsets = insertOp.getOffsets();
4573 ArrayAttr insertStrides = insertOp.getStrides();
4576 if (extractOffsets.size() > insertOffsets.size())
4578 bool patialoverlap =
false;
4579 bool disjoint =
false;
4581 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4582 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
4584 int64_t start = getElement(insertOffsets, dim);
4585 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4586 int64_t offset = getElement(extractOffsets, dim);
4587 int64_t size = getElement(extractSizes, dim);
4589 if (start <= offset && offset < end) {
4592 if (offset + size > end)
4593 patialoverlap =
true;
4594 offsetDiffs.push_back(offset - start);
4601 if (!disjoint && !patialoverlap) {
4602 op.setOperand(insertOp.getValueToStore());
4605 op.setOffsetsAttr(
b.getI64ArrayAttr(offsetDiffs));
4611 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4626 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4631 if (op.hasNonUnitStrides())
4634 VectorType sourceVecTy = op.getSourceVectorType();
4638 VectorType sliceVecTy = op.getType();
4640 int64_t rank = sliceVecTy.getRank();
4652 const auto denseValuesBegin = dense.value_begin<
Attribute>();
4654 sliceValues.reserve(sliceVecTy.getNumElements());
4658 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4660 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4661 }
while (succeeded(
incSlicePosition(currSlicePosition, sliceShape, offsets)));
4663 assert(
static_cast<int64_t>(sliceValues.size()) ==
4664 sliceVecTy.getNumElements() &&
4665 "Invalid number of slice elements");
4669OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4670 if (getSourceVectorType() == getResult().
getType())
4677 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4684void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4706class StridedSliceFolder final
4707 :
public OpRewritePattern<ExtractStridedSliceOp> {
4709 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4711 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4712 PatternRewriter &rewriter)
const override {
4713 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4717 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4720 SmallVector<int64_t> firstOffsets =
getI64SubArray(firstOp.getOffsets());
4721 SmallVector<int64_t> firstSizes =
getI64SubArray(firstOp.getSizes());
4722 SmallVector<int64_t> secondOffsets =
getI64SubArray(secondOp.getOffsets());
4723 SmallVector<int64_t> secondSizes =
getI64SubArray(secondOp.getSizes());
4725 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4726 SmallVector<int64_t> combinedOffsets(newRank, 0);
4727 SmallVector<int64_t> combinedSizes(newRank);
4728 ArrayRef<int64_t> firstSourceShape =
4729 firstOp.getSourceVectorType().getShape();
4730 for (
unsigned i = 0; i < newRank; ++i) {
4731 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4732 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4733 combinedOffsets[i] = off1 + off2;
4735 if (i < secondSizes.size()) {
4736 combinedSizes[i] = secondSizes[i];
4737 }
else if (i < firstSizes.size()) {
4738 combinedSizes[i] = firstSizes[i];
4740 combinedSizes[i] = firstSourceShape[i];
4744 SmallVector<int64_t> combinedStrides(newRank, 1);
4746 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4764class StridedSliceCreateMaskFolder final
4765 :
public OpRewritePattern<ExtractStridedSliceOp> {
4769 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4770 PatternRewriter &rewriter)
const override {
4771 Location loc = extractStridedSliceOp.getLoc();
4775 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4779 if (extractStridedSliceOp.hasNonUnitStrides())
4782 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4784 SmallVector<int64_t> sliceOffsets;
4787 SmallVector<int64_t> sliceSizes;
4791 SmallVector<Value> sliceMaskDimSizes;
4792 sliceMaskDimSizes.reserve(maskDimSizes.size());
4796 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4797 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4801 IntegerAttr offsetAttr =
4803 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4804 Value sliceMaskDimSize =
4805 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4806 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4811 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4815 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4823class StridedSliceConstantMaskFolder final
4824 :
public OpRewritePattern<ExtractStridedSliceOp> {
4828 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4829 PatternRewriter &rewriter)
const override {
4832 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4833 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4834 if (!constantMaskOp)
4837 if (extractStridedSliceOp.hasNonUnitStrides())
4840 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4842 SmallVector<int64_t> sliceOffsets;
4845 SmallVector<int64_t> sliceSizes;
4849 SmallVector<int64_t> sliceMaskDimSizes;
4850 sliceMaskDimSizes.reserve(maskDimSizes.size());
4851 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4852 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4853 int64_t sliceMaskDimSize = std::max(
4854 static_cast<int64_t
>(0),
4855 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4856 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4859 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4860 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4861 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4864 if (llvm::is_contained(sliceMaskDimSizes, 0))
4865 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4870 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4878class StridedSliceBroadcast final
4879 :
public OpRewritePattern<ExtractStridedSliceOp> {
4883 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4884 PatternRewriter &rewriter)
const override {
4890 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4891 auto dstVecType = llvm::cast<VectorType>(op.getType());
4892 unsigned dstRank = dstVecType.getRank();
4893 unsigned rankDiff = dstRank - srcRank;
4897 bool needsSlice =
false;
4898 for (
unsigned i = 0; i < srcRank; i++) {
4899 if (srcVecType.getDimSize(i) != 1 &&
4900 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4907 SmallVector<int64_t> offsets =
4909 SmallVector<int64_t> sizes =
4911 for (
unsigned i = 0; i < srcRank; i++) {
4912 if (srcVecType.getDimSize(i) == 1) {
4920 source = ExtractStridedSliceOp::create(
4921 rewriter, op->getLoc(), source, offsets, sizes,
4930class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
4934 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4935 PatternRewriter &rewriter)
const override {
4937 Value splat = getScalarSplatSource(op.getSource());
4961class ContiguousExtractStridedSliceToExtract final
4962 :
public OpRewritePattern<ExtractStridedSliceOp> {
4966 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4967 PatternRewriter &rewriter)
const override {
4968 if (op.hasNonUnitStrides())
4970 Value source = op.getOperand();
4971 auto sourceType = cast<VectorType>(source.
getType());
4972 if (sourceType.isScalable() || sourceType.getRank() == 0)
4981 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4982 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4989 if (numOffsets == 0)
4994 if (numOffsets == sourceType.getRank() &&
4995 static_cast<int>(sizes.size()) == sourceType.getRank())
4999 for (
int i = 0; i < numOffsets; ++i) {
5007 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
5008 sizes[numOffsets] == 1) {
5013 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
5014 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
5023void ExtractStridedSliceOp::getCanonicalizationPatterns(
5024 RewritePatternSet &results, MLIRContext *context) {
5027 results.
add<StridedSliceFolder, StridedSliceCreateMaskFolder,
5028 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
5029 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
5038void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5039 VectorType vectorType, Value source,
5041 AffineMapAttr permutationMapAttr,
5044 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5046 padding = ub::PoisonOp::create(builder,
result.location, elemType);
5047 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
5048 *padding, Value(), inBoundsAttr);
5052void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5053 VectorType vectorType, Value source,
5055 AffineMap permutationMap,
5056 std::optional<ArrayRef<bool>> inBounds) {
5057 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5058 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5061 SmallVector<bool>(vectorType.getRank(),
false));
5062 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5064 padding = ub::PoisonOp::create(builder,
result.location, elemType);
5065 build(builder,
result, vectorType, source,
indices, *padding,
5066 permutationMapAttr, inBoundsAttr);
5070void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
5071 VectorType vectorType, Value source,
5073 std::optional<ArrayRef<bool>> inBounds) {
5075 llvm::cast<ShapedType>(source.
getType()), vectorType);
5076 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5077 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5080 SmallVector<bool>(vectorType.getRank(),
false));
5081 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5083 padding = ub::PoisonOp::create(builder,
result.location, elemType);
5084 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
5086 Value(), inBoundsAttr);
5089template <
typename EmitFun>
5093 for (
auto expr : permutationMap.
getResults()) {
5094 auto dim = dyn_cast<AffineDimExpr>(expr);
5095 auto zero = dyn_cast<AffineConstantExpr>(expr);
5097 if (zero.getValue() != 0) {
5099 "requires a projected permutation_map (at most one dim or the zero "
5100 "constant can appear in each result)");
5105 return emitOpError(
"requires a projected permutation_map (at most one "
5106 "dim or the zero constant can appear in each result)");
5108 if (seen[dim.getPosition()]) {
5110 "requires a permutation_map that is a permutation (found one dim "
5111 "used more than once)");
5113 seen[dim.getPosition()] =
true;
5120 VectorType vectorType, VectorType maskType,
5121 VectorType inferredMaskType,
AffineMap permutationMap,
5123 if (op->hasAttr(
"masked")) {
5124 return op->emitOpError(
"masked attribute has been removed. "
5125 "Use in_bounds instead.");
5128 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
5129 return op->emitOpError(
5130 "requires source to be a memref or ranked tensor type");
5132 auto elementType = shapedType.getElementType();
5134 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
5136 unsigned sourceVecSize =
5138 vectorElementType.getShape().back();
5139 unsigned resultVecSize =
5141 vectorType.getShape().back();
5142 if (resultVecSize % sourceVecSize != 0)
5143 return op->emitOpError(
5144 "requires the bitwidth of the minor 1-D vector to be an integral "
5145 "multiple of the bitwidth of the minor 1-D vector of the source");
5147 unsigned sourceVecEltRank = vectorElementType.getRank();
5148 unsigned resultVecRank = vectorType.getRank();
5149 if (sourceVecEltRank > resultVecRank)
5150 return op->emitOpError(
5151 "requires source vector element and vector result ranks to match.");
5152 unsigned rankOffset = resultVecRank - sourceVecEltRank;
5155 return op->emitOpError(
"requires a permutation_map with result dims of "
5156 "the same rank as the vector type");
5159 return op->emitOpError(
"does not support masks with vector element type");
5162 unsigned minorSize =
5163 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
5164 unsigned resultVecSize =
5167 return op->emitOpError(
5168 "requires the bitwidth of the minor 1-D vector to be an integral "
5169 "multiple of the bitwidth of the source element type");
5173 return op->emitOpError(
"requires a permutation_map with result dims of "
5174 "the same rank as the vector type");
5178 return op->emitOpError(
"requires permutation_map without symbols");
5180 if (permutationMap.
getNumInputs() != shapedType.getRank())
5181 return op->emitOpError(
"requires a permutation_map with input dims of the "
5182 "same rank as the source type");
5184 if (maskType && maskType != inferredMaskType)
5185 return op->emitOpError(
"inferred mask type (")
5186 << inferredMaskType <<
") and mask operand type (" << maskType
5190 return op->emitOpError(
"expects the in_bounds attr of same rank "
5191 "as permutation_map results: ")
5192 << AffineMapAttr::get(permutationMap)
5193 <<
" vs inBounds of size: " << inBounds.size();
5200 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5201 if (op.getPermutationMap().isMinorIdentity())
5202 elidedAttrs.push_back(op.getPermutationMapAttrName());
5204 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
5205 elidedAttrs.push_back(op.getInBoundsAttrName());
5209void TransferReadOp::print(OpAsmPrinter &p) {
5212 p <<
", " << getMask();
5219 auto i1Type = IntegerType::get(permMap.
getContext(), 1);
5221 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
5226 if (maskShape.empty())
5227 maskShape.push_back(1);
5232 return VectorType::get(maskShape, i1Type, scalableDims);
5249 if (hasMask.succeeded()) {
5256 if (types.size() != 2)
5257 return parser.
emitError(typesLoc,
"requires two types");
5259 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5260 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5261 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5262 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5264 return parser.
emitError(typesLoc,
"requires vector type");
5265 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(
result.name);
5269 if (shapedType.getRank() <
5272 "expected a custom permutation_map when "
5273 "rank(source) != rank(destination)");
5275 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5277 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5279 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(
result.name);
5280 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5281 if (!inBoundsAttr) {
5282 result.addAttribute(inBoundsAttrName,
5291 if (hasMask.succeeded()) {
5292 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5294 maskInfo.
location,
"does not support masks with vector element type");
5297 "expected the same rank for the vector and the "
5298 "results of the permutation map");
5306 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5308 {1, static_cast<int32_t>(indexInfo.size()), 1,
5309 static_cast<int32_t>(hasMask.succeeded())}));
5313LogicalResult TransferReadOp::verify() {
5315 ShapedType shapedType = getShapedType();
5317 VectorType maskType = getMaskType();
5318 auto paddingType = getPadding().getType();
5319 auto permutationMap = getPermutationMap();
5320 VectorType inferredMaskType =
5323 auto sourceElementType = shapedType.getElementType();
5325 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
5326 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5329 shapedType, vectorType, maskType,
5330 inferredMaskType, permutationMap, getInBounds())))
5333 if (
auto sourceVectorElementType =
5334 llvm::dyn_cast<VectorType>(sourceElementType)) {
5337 if (sourceVectorElementType != paddingType)
5339 "requires source element type and padding type to match.");
5343 if (!VectorType::isValidElementType(paddingType))
5344 return emitOpError(
"requires valid padding vector elemental type");
5347 if (paddingType != sourceElementType)
5349 "requires formal padding and source of the same elemental type");
5360Type TransferReadOp::getExpectedMaskType() {
5367VectorType TransferReadOp::getVectorType() {
5368 return cast<VectorType>(getVector().
getType());
5371template <
typename TransferOp>
5375 if (op.getShapedType().isDynamicDim(indicesIdx))
5379 if (!cstOp.has_value())
5382 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5383 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5385 return cstOp.value() + vectorSize <= sourceSize;
5388template <
typename TransferOp>
5392 if (op.getTransferRank() == 0)
5395 bool changed =
false;
5397 newInBounds.reserve(op.getTransferRank());
5402 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
5404 if (op.isDimInBounds(i)) {
5405 newInBounds.push_back(
true);
5410 bool inBounds =
false;
5411 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
5414 dimExpr.getPosition());
5415 nonBcastDims.push_back(i);
5418 newInBounds.push_back(inBounds);
5420 changed |= inBounds;
5426 bool allNonBcastDimsInBounds = llvm::all_of(
5427 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
5428 if (allNonBcastDimsInBounds) {
5430 changed |= !newInBounds[idx];
5431 newInBounds[idx] =
true;
5439 op.setInBoundsAttr(
b.getBoolArrayAttr(newInBounds));
5443template <
typename TransferOp>
5445 auto mask = op.getMask();
5452 op.getMaskMutable().clear();
5460template <
typename TransferOp>
5462 VectorType vecType = op.getVectorType();
5463 if (vecType.getRank() != 1 || vecType.getShape()[0] != 1 ||
5464 vecType.isScalable())
5471 int64_t srcRank = op.getShapedType().getRank();
5477 op.setPermutationMapAttr(AffineMapAttr::get(minorIdentity));
5491static Value foldRAW(TransferReadOp readOp) {
5492 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5494 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5497 return defWrite.getVector();
5499 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5500 cast<VectorTransferOpInterface>(readOp.getOperation())))
5502 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5507OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5508 if (Value vec = foldRAW(*
this))
5521 return OpFoldResult();
5524std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5528void TransferReadOp::getEffects(
5529 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5531 if (llvm::isa<MemRefType>(getShapedType()))
5532 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5533 SideEffects::DefaultResource::get());
5537 if (hasPureTensorSemantics())
5544static AffineMap inverseWithUnusedDims(AffineMap map) {
5546 "expected a projected permutation map");
5551 int64_t pos = cast<AffineDimExpr>(
result).getPosition();
5581struct TransferReadAfterWriteToBroadcast
5582 :
public OpRewritePattern<TransferReadOp> {
5585 LogicalResult matchAndRewrite(TransferReadOp readOp,
5586 PatternRewriter &rewriter)
const override {
5587 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5591 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5595 if (readOp.getMask() || defWrite.getMask())
5598 if (readOp.getIndices() != defWrite.getIndices())
5601 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5605 if (readOp.getTransferChunkAccessed() !=
5606 defWrite.getTransferChunkAccessed())
5613 AffineMap readMap = readOp.getPermutationMap();
5614 AffineMap writeMap = defWrite.getPermutationMap();
5615 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5616 AffineMap composedMap = readMap.
compose(invWriteMap);
5630 int64_t numBroadcastedDims = broadcastedDims.size();
5631 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5633 for (
auto [idx, expr] : llvm::enumerate(composedMap.
getResults())) {
5634 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
5635 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5636 invPerm[effectiveDim] = idx;
5641 VectorType readVecTy = readOp.getVectorType();
5643 auto broadcastedVecTy =
5645 readVecTy.getElementType(),
5648 Value vec = defWrite.getVector();
5649 Location loc = readOp.getLoc();
5650 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5657void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5658 MLIRContext *context) {
5659 results.
add<TransferReadAfterWriteToBroadcast>(context);
5662FailureOr<std::optional<SmallVector<Value>>>
5663TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5664 if (!hasPureBufferSemantics())
5675void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5677 AffineMapAttr permutationMapAttr,
5680 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
5681 build(builder,
result, resultType, vector, dest,
indices, permutationMapAttr,
5682 mask, inBoundsAttr);
5686void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5688 AffineMapAttr permutationMapAttr,
5690 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5691 Value(), inBoundsAttr);
5696void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5698 AffineMap permutationMap,
5699 std::optional<ArrayRef<bool>> inBounds) {
5700 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5702 (inBounds && !inBounds.value().empty())
5705 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
5706 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5707 Value(), inBoundsAttr);
5712void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5714 std::optional<ArrayRef<bool>> inBounds) {
5715 auto vectorType = llvm::cast<VectorType>(vector.
getType());
5717 llvm::cast<ShapedType>(dest.
getType()), vectorType);
5718 build(builder,
result, vector, dest,
indices, permutationMap, inBounds);
5721ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5722 OperationState &
result) {
5725 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5726 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5727 SmallVector<Type, 2> types;
5728 OpAsmParser::UnresolvedOperand maskInfo;
5734 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
5739 if (types.size() != 2)
5740 return parser.
emitError(typesLoc,
"requires two types");
5742 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5744 return parser.
emitError(typesLoc,
"requires vector type");
5745 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5746 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5747 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5748 auto permMapAttrName =
5749 TransferWriteOp::getPermutationMapAttrName(
result.name);
5750 auto permMapAttr =
result.attributes.get(permMapAttrName);
5753 if (shapedType.getRank() <
5756 "expected a custom permutation_map when "
5757 "rank(source) != rank(destination)");
5759 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5761 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5763 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(
result.name);
5764 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5765 if (!inBoundsAttr) {
5766 result.addAttribute(inBoundsAttrName,
5774 if (hasMask.succeeded()) {
5775 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5777 maskInfo.
location,
"does not support masks with vector element type");
5780 "expected the same rank for the vector and the "
5781 "results of the permutation map");
5787 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5789 {1, 1, static_cast<int32_t>(indexInfo.size()),
5790 static_cast<int32_t>(hasMask.succeeded())}));
5791 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5795void TransferWriteOp::print(OpAsmPrinter &p) {
5798 p <<
", " << getMask();
5803LogicalResult TransferWriteOp::verify() {
5805 ShapedType shapedType = getShapedType();
5807 VectorType maskType = getMaskType();
5808 auto permutationMap = getPermutationMap();
5809 VectorType inferredMaskType =
5813 if (llvm::size(
getIndices()) != shapedType.getRank())
5814 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5818 if (hasBroadcastDim())
5819 return emitOpError(
"should not have broadcast dimensions");
5822 shapedType, vectorType, maskType,
5823 inferredMaskType, permutationMap, getInBounds())))
5836Type TransferWriteOp::getExpectedMaskType() {
5843Value TransferWriteOp::getVector() {
return getOperand(0); }
5844VectorType TransferWriteOp::getVectorType() {
5845 return cast<VectorType>(getValueToStore().
getType());
5868static LogicalResult foldReadInitWrite(TransferWriteOp write,
5869 ArrayRef<Attribute>,
5870 SmallVectorImpl<OpFoldResult> &results) {
5872 if (write.getTransferRank() == 0)
5874 auto rankedTensorType =
5875 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5877 if (!rankedTensorType)
5880 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5884 if (read.getTransferRank() == 0)
5887 if (!read.getPermutationMap().isMinorIdentity() ||
5888 !write.getPermutationMap().isMinorIdentity())
5891 if (read.getTransferRank() != write.getTransferRank())
5894 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5897 if (read.getMask() || write.getMask())
5900 if (read.getBase().getType() != rankedTensorType)
5903 if (read.getVectorType() != write.getVectorType())
5906 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5909 auto isNotConstantZero = [](Value v) {
5911 return !cstOp.has_value() || cstOp.value() != 0;
5913 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5914 llvm::any_of(write.getIndices(), isNotConstantZero))
5917 results.push_back(read.getBase());
5921static bool checkSameValueWAR(vector::TransferReadOp read,
5922 vector::TransferWriteOp write) {
5923 return read.getBase() == write.getBase() &&
5924 read.getIndices() == write.getIndices() &&
5925 read.getPermutationMap() == write.getPermutationMap() &&
5926 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5943static LogicalResult foldWAR(TransferWriteOp write,
5944 SmallVectorImpl<OpFoldResult> &results) {
5945 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5947 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5951 if (!checkSameValueWAR(read, write))
5953 results.push_back(read.getBase());
5957LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5958 SmallVectorImpl<OpFoldResult> &results) {
5959 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
5961 if (succeeded(foldWAR(*
this, results)))
5975std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5979void TransferWriteOp::getEffects(
5980 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5982 if (llvm::isa<MemRefType>(getShapedType()))
5983 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5984 SideEffects::DefaultResource::get());
5988 if (hasPureTensorSemantics())
6018class FoldWaw final :
public OpRewritePattern<TransferWriteOp> {
6021 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
6022 PatternRewriter &rewriter)
const override {
6023 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
6025 vector::TransferWriteOp writeToModify = writeOp;
6027 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
6031 writeToModify.getBaseMutable().assign(defWrite.getBase());
6036 cast<VectorTransferOpInterface>(defWrite.getOperation()),
6037 cast<VectorTransferOpInterface>(writeOp.getOperation())))
6041 if (!defWrite->hasOneUse())
6043 writeToModify = defWrite;
6044 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
6073struct SwapExtractSliceOfTransferWrite
6074 :
public OpRewritePattern<tensor::InsertSliceOp> {
6078 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
6079 PatternRewriter &rewriter)
const override {
6080 if (!insertOp.hasUnitStride())
6083 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
6084 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
6086 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
6087 if (!transferOp || !transferOp->hasOneUse())
6092 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
6094 "use-def chain is rank-reducing");
6098 if (!extractOp.hasZeroOffset()) {
6100 "ExtractSliceOp has non-zero offset");
6104 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
6105 return getConstantIntValue(value) == static_cast<int64_t>(0);
6108 "TranferWriteOp has non-zero offset");
6112 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
6114 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
6117 for (
auto [insertSize, extractSize] :
6118 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
6121 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
6126 assert(transferOp.getVectorType().hasStaticShape() &&
6127 "expected vector to have a static shape");
6128 ArrayRef<int64_t>
vectorShape = transferOp.getVectorType().getShape();
6130 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
6131 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
6133 insertOp,
"TransferWriteOp may not write the full tensor.");
6138 SmallVector<bool> newInBounds(
vectorShape.size(),
false);
6139 auto newExtractOp = tensor::ExtractSliceOp::create(
6140 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
6141 insertOp.getDest(), insertOp.getMixedOffsets(),
6142 insertOp.getMixedSizes(), insertOp.getMixedStrides());
6143 auto newTransferWriteOp = TransferWriteOp::create(
6144 rewriter, transferOp.getLoc(), transferOp.getVector(),
6145 newExtractOp.getResult(), transferOp.getIndices(),
6146 transferOp.getPermutationMapAttr(),
6149 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
6157void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
6158 MLIRContext *context) {
6159 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
6162FailureOr<std::optional<SmallVector<Value>>>
6163TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
6164 if (!hasPureBufferSemantics())
6174static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
6176 MemRefType memRefTy) {
6179 if (!vecTy.isScalable() &&
6180 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
6183 if (!memRefTy.isLastDimUnitStride())
6184 return op->
emitOpError(
"most minor memref dim must have unit stride");
6188LogicalResult vector::LoadOp::verify() {
6192 if (
failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
6195 if (memRefTy.getRank() < resVecTy.getRank())
6197 "destination memref has lower rank than the result vector");
6200 Type memElemTy = memRefTy.getElementType();
6201 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6202 if (memVecTy != resVecTy)
6203 return emitOpError(
"base memref and result vector types should match");
6204 memElemTy = memVecTy.getElementType();
6207 if (resVecTy.getElementType() != memElemTy)
6208 return emitOpError(
"base and result element types should match");
6209 if (llvm::size(
getIndices()) != memRefTy.getRank())
6210 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6214OpFoldResult LoadOp::fold(FoldAdaptor) {
6217 return OpFoldResult();
6220std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6224FailureOr<std::optional<SmallVector<Value>>>
6225LoadOp::bubbleDownCasts(OpBuilder &builder) {
6234LogicalResult vector::StoreOp::verify() {
6238 if (
failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
6241 if (memRefTy.getRank() < valueVecTy.getRank())
6242 return emitOpError(
"source memref has lower rank than the vector to store");
6245 Type memElemTy = memRefTy.getElementType();
6246 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6247 if (memVecTy != valueVecTy)
6249 "base memref and valueToStore vector types should match");
6250 memElemTy = memVecTy.getElementType();
6253 if (valueVecTy.getElementType() != memElemTy)
6254 return emitOpError(
"base and valueToStore element type should match");
6255 if (llvm::size(
getIndices()) != memRefTy.getRank())
6256 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6260LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6261 SmallVectorImpl<OpFoldResult> &results) {
6265std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6269FailureOr<std::optional<SmallVector<Value>>>
6270StoreOp::bubbleDownCasts(OpBuilder &builder) {
6279LogicalResult MaskedLoadOp::verify() {
6280 VectorType maskVType = getMaskVectorType();
6281 VectorType passVType = getPassThruVectorType();
6288 if (llvm::size(
getIndices()) != memType.getRank())
6289 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6290 if (resVType.getShape() != maskVType.getShape())
6291 return emitOpError(
"expected result shape to match mask shape");
6292 if (resVType != passVType)
6293 return emitOpError(
"expected pass_thru of same type as result type");
6298class MaskedLoadFolder final :
public OpRewritePattern<MaskedLoadOp> {
6301 LogicalResult matchAndRewrite(MaskedLoadOp
load,
6302 PatternRewriter &rewriter)
const override {
6314 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
6319void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6320 MLIRContext *context) {
6321 results.
add<MaskedLoadFolder>(context);
6324OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6327 return OpFoldResult();
6330FailureOr<std::optional<SmallVector<Value>>>
6331MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6340LogicalResult MaskedStoreOp::verify() {
6341 VectorType maskVType = getMaskVectorType();
6348 if (llvm::size(
getIndices()) != memType.getRank())
6349 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6350 if (valueVType.getShape() != maskVType.getShape())
6351 return emitOpError(
"expected valueToStore shape to match mask shape");
6356class MaskedStoreFolder final :
public OpRewritePattern<MaskedStoreOp> {
6359 LogicalResult matchAndRewrite(MaskedStoreOp store,
6360 PatternRewriter &rewriter)
const override {
6364 store, store.getValueToStore(), store.getBase(), store.getIndices());
6372 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
6377void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6378 MLIRContext *context) {
6379 results.
add<MaskedStoreFolder>(context);
6382LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6383 SmallVectorImpl<OpFoldResult> &results) {
6387FailureOr<std::optional<SmallVector<Value>>>
6388MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6397LogicalResult GatherOp::verify() {
6398 VectorType indVType = getIndexVectorType();
6399 VectorType maskVType = getMaskVectorType();
6401 ShapedType baseType = getBaseType();
6403 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6404 return emitOpError(
"requires base to be a memref or ranked tensor type");
6409 if (llvm::size(getOffsets()) != baseType.getRank())
6410 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6411 if (resVType.getShape() != indVType.getShape())
6412 return emitOpError(
"expected result dim to match indices dim");
6413 if (resVType.getShape() != maskVType.getShape())
6414 return emitOpError(
"expected result dim to match mask dim");
6415 if (resVType != getPassThruVectorType())
6416 return emitOpError(
"expected pass_thru of same type as result type");
6417 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6419 "alignment is only supported for memref bases, not tensor bases");
6428Type GatherOp::getExpectedMaskType() {
6429 auto vecType = this->getIndexVectorType();
6430 return VectorType::get(vecType.getShape(),
6431 IntegerType::get(vecType.getContext(), 1),
6432 vecType.getScalableDims());
6435std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6440static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6441 auto vecType = dyn_cast<VectorType>(indexVec.
getType());
6442 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6448 DenseIntElementsAttr elements;
6453 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6457class GatherFolder final :
public OpRewritePattern<GatherOp> {
6460 LogicalResult matchAndRewrite(GatherOp gather,
6461 PatternRewriter &rewriter)
const override {
6466 rewriter.
replaceOp(gather, gather.getPassThru());
6471 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
6477class FoldContiguousGather final :
public OpRewritePattern<GatherOp> {
6480 LogicalResult matchAndRewrite(GatherOp op,
6481 PatternRewriter &rewriter)
const override {
6482 if (!isa<MemRefType>(op.getBase().getType()))
6485 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6489 op.getOffsets(), op.getMask(),
6496void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6497 MLIRContext *context) {
6498 results.
add<GatherFolder, FoldContiguousGather>(context);
6501FailureOr<std::optional<SmallVector<Value>>>
6502GatherOp::bubbleDownCasts(OpBuilder &builder) {
6511LogicalResult ScatterOp::verify() {
6512 VectorType indVType = getIndexVectorType();
6513 VectorType maskVType = getMaskVectorType();
6515 ShapedType baseType = getBaseType();
6517 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6518 return emitOpError(
"requires base to be a memref or ranked tensor type");
6523 if (llvm::size(getOffsets()) != baseType.getRank())
6524 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6525 if (valueVType.getShape() != indVType.getShape())
6526 return emitOpError(
"expected valueToStore dim to match indices dim");
6527 if (valueVType.getShape() != maskVType.getShape())
6528 return emitOpError(
"expected valueToStore dim to match mask dim");
6529 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6531 "alignment is only supported for memref bases, not tensor bases");
6536class ScatterFolder final :
public OpRewritePattern<ScatterOp> {
6539 LogicalResult matchAndRewrite(ScatterOp scatter,
6540 PatternRewriter &rewriter)
const override {
6541 ShapedType baseType = scatter.getBaseType();
6542 bool isMemRef = isa<MemRefType>(baseType);
6543 if (!isMemRef && !isa<RankedTensorType>(baseType))
6556 rewriter.
replaceOp(scatter, scatter.getBase());
6561 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
6567class FoldContiguousScatter final :
public OpRewritePattern<ScatterOp> {
6570 LogicalResult matchAndRewrite(ScatterOp op,
6571 PatternRewriter &rewriter)
const override {
6574 if (!isa<MemRefType>(op.getBase().getType()))
6577 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6581 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6587void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6588 MLIRContext *context) {
6589 results.
add<ScatterFolder, FoldContiguousScatter>(context);
6592FailureOr<std::optional<SmallVector<Value>>>
6593ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6602LogicalResult ExpandLoadOp::verify() {
6603 VectorType maskVType = getMaskVectorType();
6604 VectorType passVType = getPassThruVectorType();
6611 if (llvm::size(
getIndices()) != memType.getRank())
6612 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6613 if (resVType.getShape() != maskVType.getShape())
6614 return emitOpError(
"expected result shape to match mask shape");
6615 if (resVType != passVType)
6616 return emitOpError(
"expected pass_thru of same type as result type");
6621class ExpandLoadFolder final :
public OpRewritePattern<ExpandLoadOp> {
6624 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6625 PatternRewriter &rewriter)
const override {
6629 expand, expand.getType(), expand.getBase(), expand.getIndices());
6632 rewriter.
replaceOp(expand, expand.getPassThru());
6637 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
6642void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6643 MLIRContext *context) {
6644 results.
add<ExpandLoadFolder>(context);
6647FailureOr<std::optional<SmallVector<Value>>>
6648ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6657LogicalResult CompressStoreOp::verify() {
6658 VectorType maskVType = getMaskVectorType();
6665 if (llvm::size(
getIndices()) != memType.getRank())
6666 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6667 if (valueVType.getShape() != maskVType.getShape())
6668 return emitOpError(
"expected valueToStore shape to match mask shape");
6673class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
6676 LogicalResult matchAndRewrite(CompressStoreOp compress,
6677 PatternRewriter &rewriter)
const override {
6681 compress, compress.getValueToStore(), compress.getBase(),
6682 compress.getIndices());
6690 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
6695void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6696 MLIRContext *context) {
6697 results.
add<CompressStoreFolder>(context);
6700FailureOr<std::optional<SmallVector<Value>>>
6701CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6710void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6712 setResultRanges(getResult(), argRanges.front());
6715std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6716 return llvm::to_vector<4>(getResultVectorType().
getShape());
6719LogicalResult ShapeCastOp::verify() {
6721 VectorType sourceType = getSourceVectorType();
6722 VectorType resultType = getResultVectorType();
6730 int64_t sourceNElms = sourceType.getNumElements();
6731 int64_t resultNElms = resultType.getNumElements();
6732 if (sourceNElms != resultNElms) {
6733 return emitOpError() <<
"has different number of elements at source ("
6734 << sourceNElms <<
") and result (" << resultNElms
6739 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6740 int64_t resultNScalableDims = resultType.getNumScalableDims();
6741 if (sourceNScalableDims != resultNScalableDims)
6742 return emitOpError() <<
"has different number of scalable dims at source ("
6743 << sourceNScalableDims <<
") and result ("
6744 << resultNScalableDims <<
")";
6753static bool isOrderPreserving(TransposeOp transpose) {
6754 ArrayRef<int64_t> permutation = transpose.getPermutation();
6755 VectorType sourceType = transpose.getSourceVectorType();
6756 ArrayRef<int64_t> inShape = sourceType.getShape();
6757 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6758 auto isNonScalableUnitDim = [&](int64_t dim) {
6759 return inShape[dim] == 1 && !inDimIsScalable[dim];
6761 int64_t current = 0;
6762 for (
auto p : permutation) {
6763 if (!isNonScalableUnitDim(p)) {
6773OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6775 VectorType resultType =
getType();
6778 if (getSource().
getType() == resultType)
6782 if (
auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6783 setOperand(precedingShapeCast.getSource());
6788 if (
auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6789 if (isOrderPreserving(transpose)) {
6790 setOperand(transpose.getVector());
6798 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6799 if (bcastOp.getSourceType() == resultType)
6800 return bcastOp.getSource();
6804 if (
auto denseAttr =
6805 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6806 return denseAttr.reshape(
getType());
6822static VectorType trimTrailingOneDims(VectorType oldType) {
6823 ArrayRef<int64_t> oldShape = oldType.getShape();
6824 ArrayRef<int64_t> newShape = oldShape;
6826 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6827 ArrayRef<bool> newScalableDims = oldScalableDims;
6829 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6830 newShape = newShape.drop_back(1);
6831 newScalableDims = newScalableDims.drop_back(1);
6836 if (newShape.empty()) {
6837 newShape = oldShape.take_back();
6838 newScalableDims = oldScalableDims.take_back();
6841 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6856class ShapeCastCreateMaskFolderTrailingOneDim final
6857 :
public OpRewritePattern<ShapeCastOp> {
6861 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6862 PatternRewriter &rewriter)
const override {
6863 Value shapeOpSrc = shapeOp->getOperand(0);
6864 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
6865 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
6866 if (!createMaskOp && !constantMaskOp)
6869 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6870 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6872 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6873 if (newVecType != shapeOpResTy)
6876 auto numDimsToDrop =
6877 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6884 auto maskOperands = createMaskOp.getOperands();
6885 auto numMaskOperands = maskOperands.size();
6888 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6890 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6891 if (!constant || (constant.value() != 1))
6894 SmallVector<Value> newMaskOperands =
6895 maskOperands.drop_back(numDimsToDrop);
6902 if (constantMaskOp) {
6903 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6904 auto numMaskOperands = maskDimSizes.size();
6907 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6909 if (maskDimSizes[i] != 1)
6913 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6926int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6927 ArrayRef<int64_t> dstShape) {
6928 int stretchingFactor = 1;
6929 int numLeadingDims = dstShape.size() - srcShape.size();
6930 for (
int i = 0, e = srcShape.size(); i < e; i++) {
6931 int64_t dstDim = dstShape[numLeadingDims + i];
6932 if (srcShape[i] == 1 && dstDim != 1) {
6933 stretchingFactor *= dstDim;
6936 return stretchingFactor;
6940class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
6944 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6945 PatternRewriter &rewriter)
const override {
6947 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6951 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6952 bool srcIsScalar = !srcVectorType;
6960 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6961 ArrayRef<int64_t> dstShape = dstVectorType.getShape();
6962 ArrayRef<int64_t> srcShape =
6963 srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
6964 ArrayRef<int64_t> broadcastShape =
6965 broadcastOp.getResultVectorType().getShape();
6969 BroadcastableToResult::Success) {
6977 if (srcVectorType.getNumElements() != 1) {
6978 if (getBroadcastStretchingFactor(srcShape, dstShape) !=
6979 getBroadcastStretchingFactor(srcShape, broadcastShape)) {
6986 broadcastOp.getSource());
7005class FoldShapeCastOfFromElements final :
public OpRewritePattern<ShapeCastOp> {
7009 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
7010 PatternRewriter &rewriter)
const override {
7011 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
7016 shapeCastOp, shapeCastOp.getResultVectorType(),
7017 fromElements.getElements());
7024void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
7025 MLIRContext *context) {
7026 results.
add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
7027 FoldShapeCastOfFromElements>(context);
7034LogicalResult BitCastOp::verify() {
7035 auto sourceVectorType = getSourceVectorType();
7036 auto resultVectorType = getResultVectorType();
7038 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
7039 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
7040 return emitOpError(
"dimension size mismatch at: ") << i;
7043 DataLayout dataLayout = DataLayout::closest(*
this);
7044 auto sourceElementBits =
7046 auto resultElementBits =
7049 if (sourceVectorType.getRank() == 0) {
7050 if (sourceElementBits != resultElementBits)
7051 return emitOpError(
"source/result bitwidth of the 0-D vector element "
7052 "types must be equal");
7053 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
7054 resultElementBits * resultVectorType.getShape().back()) {
7056 "source/result bitwidth of the minor 1-D vectors must be equal");
7062OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
7068 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
7069 if (getResult().
getType() == otherOp.getSource().getType())
7070 return otherOp.getSource();
7072 setOperand(otherOp.getSource());
7076 Attribute sourceConstant = adaptor.getSource();
7077 if (!sourceConstant)
7080 Type srcElemType = getSourceVectorType().getElementType();
7081 Type dstElemType = getResultVectorType().getElementType();
7083 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
7084 if (floatPack.isSplat()) {
7085 auto splat = floatPack.getSplatValue<FloatAttr>();
7088 if (srcElemType.
isF16() && dstElemType.
isF32()) {
7089 uint32_t bits =
static_cast<uint32_t
>(
7090 splat.getValue().bitcastToAPInt().getZExtValue());
7092 bits = (bits << 16) | (bits & 0xffff);
7093 APInt intBits(32, bits);
7094 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
7100 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
7101 if (intPack.isSplat()) {
7102 auto splat = intPack.getSplatValue<IntegerAttr>();
7104 if (llvm::isa<IntegerType>(dstElemType) && srcElemType.
isIntOrFloat()) {
7109 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
7110 APInt intBits = splat.getValue().zext(dstBitWidth);
7113 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
7114 intBits = (intBits << srcBitWidth) | intBits;
7124std::optional<SmallVector<int64_t, 4>> BitCastOp::getShapeForUnroll() {
7125 return llvm::to_vector<4>(getResultVectorType().
getShape());
7132static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
7133 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
7134 SmallVector<int64_t, 8> res(memRefType.getShape());
7136 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
7142void TypeCastOp::build(OpBuilder &builder, OperationState &
result,
7144 result.addOperands(source);
7145 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
7146 VectorType vectorType =
7147 VectorType::get(extractShape(memRefType),
7149 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
7150 memRefType.getMemorySpace()));
7153LogicalResult TypeCastOp::verify() {
7154 MemRefType canonicalType =
getMemRefType().canonicalizeStridedLayout();
7155 if (!canonicalType.getLayout().isIdentity())
7156 return emitOpError(
"expects operand to be a memref with identity layout");
7157 if (!getResultMemRefType().getLayout().isIdentity())
7158 return emitOpError(
"expects result to be a memref with identity layout");
7159 if (getResultMemRefType().getMemorySpace() !=
7161 return emitOpError(
"expects result in same memory space");
7164 auto resultType = getResultMemRefType();
7168 "expects result and operand with same underlying scalar type: ")
7170 if (extractShape(sourceType) != extractShape(resultType))
7172 "expects concatenated result and operand shapes to be equal: ")
7181void vector::TransposeOp::build(OpBuilder &builder, OperationState &
result,
7182 Value vector, ArrayRef<int64_t> permutation) {
7183 VectorType vt = llvm::cast<VectorType>(vector.
getType());
7184 SmallVector<int64_t, 4> transposedShape(vt.getRank());
7185 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
7186 for (
unsigned i = 0; i < permutation.size(); ++i) {
7187 transposedShape[i] = vt.getShape()[permutation[i]];
7188 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
7191 result.addOperands(vector);
7192 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
7193 transposedScalableDims));
7194 result.addAttribute(TransposeOp::getPermutationAttrName(
result.name),
7198OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
7201 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
7202 return splat.reshape(getResultVectorType());
7219 if (getSourceVectorType() == getResultVectorType() &&
7220 isOrderPreserving(*
this))
7226LogicalResult vector::TransposeOp::verify() {
7227 VectorType vectorType = getSourceVectorType();
7228 VectorType resultType = getResultVectorType();
7229 int64_t rank = resultType.getRank();
7230 if (vectorType.getRank() != rank)
7231 return emitOpError(
"vector result rank mismatch: ") << rank;
7233 ArrayRef<int64_t> perm = getPermutation();
7234 int64_t size = perm.size();
7236 return emitOpError(
"transposition length mismatch: ") << size;
7237 SmallVector<bool, 8> seen(rank,
false);
7238 for (
const auto &ta : llvm::enumerate(perm)) {
7239 if (ta.value() < 0 || ta.value() >= rank)
7240 return emitOpError(
"transposition index out of range: ") << ta.value();
7241 if (seen[ta.value()])
7242 return emitOpError(
"duplicate position index: ") << ta.value();
7243 seen[ta.value()] =
true;
7244 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7245 return emitOpError(
"dimension size mismatch at: ") << ta.value();
7250std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7251 return llvm::to_vector<4>(getResultVectorType().
getShape());
7254void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7256 setResultRanges(getResult(), argRanges.front());
7262class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
7266 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7267 PatternRewriter &rewriter)
const override {
7269 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7270 ArrayRef<int64_t> permutation2) {
7271 SmallVector<int64_t, 4>
result;
7272 for (
auto index : permutation2)
7273 result.push_back(permutation1[index]);
7278 vector::TransposeOp parentTransposeOp =
7279 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7280 if (!parentTransposeOp)
7283 SmallVector<int64_t, 4> permutation = composePermutations(
7284 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7287 transposeOp, transposeOp.getResult().
getType(),
7288 parentTransposeOp.getVector(), permutation);
7294class FoldTransposeSplat final :
public OpRewritePattern<TransposeOp> {
7298 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7299 PatternRewriter &rewriter)
const override {
7300 Value splat = getScalarSplatSource(transposeOp.getVector());
7305 transposeOp, transposeOp.getResultVectorType(), splat);
7311class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
7315 LogicalResult matchAndRewrite(TransposeOp transpOp,
7316 PatternRewriter &rewriter)
const override {
7317 Value transposeSrc = transpOp.getVector();
7318 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
7319 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
7320 if (!createMaskOp && !constantMaskOp)
7325 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7328 auto maskOperands = createMaskOp.getOperands();
7329 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7333 transpOp, transpOp.getResultVectorType(), newOperands);
7338 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7342 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7348class FoldTransposeShapeCast final :
public OpRewritePattern<TransposeOp> {
7352 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7353 PatternRewriter &rewriter)
const override {
7355 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7358 if (!isOrderPreserving(transposeOp))
7361 VectorType resultType = transposeOp.getType();
7368 shapeCastOp.getSource());
7387class FoldTransposeFromElements final :
public OpRewritePattern<TransposeOp> {
7390 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7391 PatternRewriter &rewriter)
const override {
7392 auto fromElementsOp =
7393 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7394 if (!fromElementsOp)
7397 VectorType srcTy = fromElementsOp.getDest().getType();
7398 VectorType dstTy = transposeOp.getType();
7400 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7401 int64_t rank = srcTy.getRank();
7404 SmallVector<int64_t> inversePerm(rank, 0);
7405 for (int64_t i = 0; i < rank; ++i)
7406 inversePerm[permutation[i]] = i;
7408 ArrayRef<int64_t> srcShape = srcTy.getShape();
7409 ArrayRef<int64_t> dstShape = dstTy.getShape();
7410 SmallVector<int64_t> srcIdx(rank, 0);
7411 SmallVector<int64_t> dstIdx(rank, 0);
7415 auto elementsOld = fromElementsOp.getElements();
7416 SmallVector<Value> elementsNew;
7417 int64_t dstNumElements = dstTy.getNumElements();
7418 elementsNew.reserve(dstNumElements);
7422 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7426 for (int64_t j = 0; j < rank; ++j)
7427 srcIdx[j] = dstIdx[inversePerm[j]];
7429 int64_t srcLin =
linearize(srcIdx, srcStrides);
7431 elementsNew.push_back(elementsOld[srcLin]);
7465class FoldTransposeBroadcast :
public OpRewritePattern<vector::TransposeOp> {
7468 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7469 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7471 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7472 PatternRewriter &rewriter)
const override {
7478 "not preceded by a broadcast");
7481 auto inputType = dyn_cast<VectorType>(
broadcast.getSourceType());
7482 VectorType outputType = transpose.getResultVectorType();
7485 bool inputIsScalar = !inputType;
7486 if (inputIsScalar) {
7492 ArrayRef<int64_t> permutation = transpose.getPermutation();
7493 ArrayRef<int64_t> inputShape = inputType.getShape();
7494 int64_t inputRank = inputType.getRank();
7495 int64_t outputRank = transpose.getType().getRank();
7496 int64_t deltaRank = outputRank - inputRank;
7499 for (
int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7500 bool notOne = inputShape[inputIndex] != 1;
7501 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7502 bool groupEndFound = notOne || prevNotOne;
7503 if (groupEndFound) {
7504 int high = inputIndex + deltaRank;
7508 for (
int i = low; i < high; ++i) {
7509 if (permutation[i] < low || permutation[i] >= high) {
7511 transpose,
"permutation not local to group");
7525 vector::BroadcastableToResult::Success &&
7526 "not broadcastable directly to transpose output");
7537void vector::TransposeOp::getCanonicalizationPatterns(
7538 RewritePatternSet &results, MLIRContext *context) {
7539 results.
add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7540 FoldTransposeSplat, FoldTransposeFromElements,
7541 FoldTransposeBroadcast>(context);
7548void ConstantMaskOp::build(OpBuilder &builder, OperationState &
result,
7550 assert(kind == ConstantMaskKind::AllTrue ||
7551 kind == ConstantMaskKind::AllFalse);
7552 build(builder,
result, type,
7553 kind == ConstantMaskKind::AllTrue
7555 : SmallVector<int64_t>(type.getRank(), 0));
7558LogicalResult ConstantMaskOp::verify() {
7559 auto resultType = llvm::cast<VectorType>(getResult().
getType());
7561 if (resultType.getRank() == 0) {
7562 if (getMaskDimSizes().size() != 1)
7563 return emitError(
"array attr must have length 1 for 0-D vectors");
7564 auto dim = getMaskDimSizes()[0];
7565 if (dim != 0 && dim != 1)
7566 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
7571 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
7573 "must specify array attr of size equal vector result rank");
7576 auto resultShape = resultType.getShape();
7577 auto resultScalableDims = resultType.getScalableDims();
7578 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7579 for (
const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7580 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7582 "array attr of size out of bounds of vector result dimension size");
7583 if (resultScalableDims[index] && maskDimSize != 0 &&
7584 maskDimSize != resultShape[index])
7586 "only supports 'none set' or 'all set' scalable dimensions");
7590 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7591 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
7592 if (anyZeros && !allZeros)
7593 return emitOpError(
"expected all mask dim sizes to be zeros, "
7594 "as a result of conjunction with zero mask dim");
7598bool ConstantMaskOp::isAllOnesMask() {
7601 if (resultType.getRank() == 0) {
7602 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
7603 return getMaskDimSizes()[0] == 1;
7605 for (
const auto [resultSize, maskDimSize] :
7606 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7607 if (maskDimSize < resultSize)
7613OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7614 ArrayRef<int64_t> bounds = getMaskDimSizes();
7617 auto createBoolSplat = [&](
bool x) {
7623 if (vectorSizes.empty()) {
7624 assert(bounds.size() == 1 &&
"invalid sizes for zero rank mask");
7625 return createBoolSplat(bounds[0] == 1);
7628 if (bounds == vectorSizes)
7629 return createBoolSplat(
true);
7630 if (llvm::all_of(bounds, [](int64_t x) {
return x == 0; }))
7631 return createBoolSplat(
false);
7632 return OpFoldResult();
7639void CreateMaskOp::build(OpBuilder &builder, OperationState &
result,
7641 ArrayRef<OpFoldResult> mixedOperands) {
7642 SmallVector<Value> operands =
7644 build(builder,
result, type, operands);
7647LogicalResult CreateMaskOp::verify() {
7648 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
7650 if (vectorType.getRank() == 0) {
7651 if (getNumOperands() != 1)
7653 "must specify exactly one operand for 0-D create_mask");
7654 }
else if (getNumOperands() !=
7655 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
7657 "must specify an operand for each result vector dimension");
7687class CreateMaskFolder final :
public OpRewritePattern<CreateMaskOp> {
7691 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7692 PatternRewriter &rewriter)
const override {
7693 VectorType maskType = createMaskOp.getVectorType();
7694 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7695 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7698 constexpr std::array<int64_t, 1> rankZeroShape{1};
7699 constexpr std::array<bool, 1> rankZeroScalableDims{
false};
7700 if (maskType.getRank() == 0) {
7701 maskTypeDimSizes = rankZeroShape;
7702 maskTypeDimScalableFlags = rankZeroScalableDims;
7707 SmallVector<int64_t, 4> constantDims;
7708 for (
auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7713 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7715 constantDims.push_back(*intSize);
7719 if (vscaleMultiplier < maskTypeDimSizes[i])
7721 constantDims.push_back(*vscaleMultiplier);
7728 for (
auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7729 value = std::clamp<int64_t>(value, 0, maskDimSize);
7732 if (llvm::is_contained(constantDims, 0))
7733 constantDims.assign(constantDims.size(), 0);
7744void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7745 MLIRContext *context) {
7746 results.
add<CreateMaskFolder>(context);
7754 OpBuilder &builder, OperationState &
result, Value mask,
7755 Operation *maskableOp,
7756 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7757 assert(maskRegionBuilder &&
7758 "builder callback for 'maskRegion' must be present");
7760 result.addOperands(mask);
7761 OpBuilder::InsertionGuard guard(builder);
7762 Region *maskRegion =
result.addRegion();
7764 maskRegionBuilder(builder, maskableOp);
7769 Value mask, Operation *maskableOp,
7770 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7771 build(builder,
result, resultTypes, mask, Value(), maskableOp,
7777 Value mask, Value passthru, Operation *maskableOp,
7778 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7779 build(builder,
result, mask, maskableOp, maskRegionBuilder);
7781 result.addOperands(passthru);
7782 result.addTypes(resultTypes);
7785ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &
result) {
7787 result.regions.reserve(1);
7788 Region &maskRegion = *
result.addRegion();
7793 OpAsmParser::UnresolvedOperand mask;
7798 OpAsmParser::UnresolvedOperand passthru;
7800 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
7807 MaskOp::ensureTerminator(maskRegion, builder,
result.location);
7818 SmallVector<Type> resultTypes;
7821 result.types.append(resultTypes);
7827 if (parsePassthru.succeeded()) {
7828 if (resultTypes.empty())
7831 "expects a result if passthru operand is provided");
7840void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7841 p <<
" " << getMask();
7843 p <<
", " << getPassthru();
7847 Block *singleBlock = &getMaskRegion().getBlocks().front();
7854 p <<
" : " << getMask().getType();
7855 if (getNumResults() > 0)
7856 p <<
" -> " << getResultTypes();
7859void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
7862 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7863 MaskOp>::ensureTerminator(region, builder, loc);
7869 if (isa<vector::YieldOp>(block.
back()))
7877 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7878 MaskOp>::ensureTerminator(region, builder, loc);
7884 Operation *maskedOp = &block.
front();
7885 opBuilder.setInsertionPointToEnd(&block);
7886 vector::YieldOp::create(opBuilder, loc, maskedOp->
getResults());
7889LogicalResult MaskOp::verify() {
7891 Block &block = getMaskRegion().getBlocks().
front();
7893 return emitOpError(
"expects a terminator within the mask region");
7896 if (numMaskRegionOps > 2)
7897 return emitOpError(
"expects only one operation to mask");
7900 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
7902 return emitOpError(
"expects a terminator within the mask region");
7904 if (terminator->getNumOperands() != getNumResults())
7906 "expects number of results to match mask region yielded values");
7909 if (numMaskRegionOps == 1)
7912 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
7914 return emitOpError(
"expects a MaskableOpInterface within the mask region");
7918 return emitOpError(
"expects number of results to match maskable operation "
7919 "number of results");
7921 if (!llvm::equal(maskableOp->
getResults(), terminator.getOperands()))
7922 return emitOpError(
"expects all the results from the MaskableOpInterface "
7923 "to match all the values returned by the terminator");
7925 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
7927 "expects result type to match maskable operation result type");
7930 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7931 return emitOpError(
"multiple vector results not supported");
7934 Type expectedMaskType = maskableOp.getExpectedMaskType();
7935 if (getMask().
getType() != expectedMaskType)
7937 << expectedMaskType <<
" mask for the maskable operation";
7940 Value passthru = getPassthru();
7942 if (!maskableOp.supportsPassthru())
7944 "doesn't expect a passthru argument for this maskable operation");
7947 return emitOpError(
"expects result when passthru argument is provided");
7950 return emitOpError(
"expects passthru type to match result type");
7970static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7971 SmallVectorImpl<OpFoldResult> &results) {
7972 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7975 Block *block = maskOp.getMaskBlock();
7976 auto terminator = cast<vector::YieldOp>(block->
front());
7977 if (terminator.getNumOperands() == 0)
7981 llvm::append_range(results, terminator.getOperands());
7985LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7986 SmallVectorImpl<OpFoldResult> &results) {
7987 if (succeeded(foldEmptyMaskOp(*
this, adaptor, results)))
7997 Operation *maskableOp = getMaskableOp();
8003 llvm::append_range(results, maskableOp->
getResults());
8019class CanonializeEmptyMaskOp :
public OpRewritePattern<MaskOp> {
8022 LogicalResult matchAndRewrite(MaskOp maskOp,
8023 PatternRewriter &rewriter)
const override {
8024 if (!maskOp.isEmpty())
8027 if (!maskOp.hasPassthru())
8034 VectorType maskType = maskOp.getMask().getType();
8035 for (Type resultType : maskOp.getResultTypes()) {
8036 auto vecResultType = dyn_cast<VectorType>(resultType);
8037 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
8041 Block *block = maskOp.getMaskBlock();
8042 auto terminator = cast<vector::YieldOp>(block->
front());
8043 assert(terminator.getNumOperands() == 1 &&
8044 "expected one result when passthru is provided");
8047 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
8048 terminator.getOperand(0), maskOp.getPassthru());
8054void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
8055 MLIRContext *context) {
8056 results.
add<CanonializeEmptyMaskOp>(context);
8062Operation *MaskOp::getMaskableOp() {
8063 Block *block = getMaskBlock();
8067 return &block->
front();
8071bool MaskOp::hasPassthru() {
return getPassthru() != Value(); }
8077LogicalResult ScanOp::verify() {
8078 VectorType srcType = getSourceType();
8079 VectorType initialType = getInitialValueType();
8081 int64_t srcRank = srcType.getRank();
8082 int64_t reductionDim = getReductionDim();
8083 if (reductionDim >= srcRank)
8085 << reductionDim <<
" has to be less than " << srcRank;
8088 int64_t initialValueRank = initialType.getRank();
8089 if (initialValueRank != srcRank - 1)
8091 << initialValueRank <<
" has to be equal to " << srcRank - 1;
8094 ArrayRef<int64_t> srcShape = srcType.getShape();
8095 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
8096 SmallVector<int64_t> expectedShape;
8097 for (
int i = 0; i < srcRank; i++) {
8098 if (i != reductionDim)
8099 expectedShape.push_back(srcShape[i]);
8101 if (!llvm::equal(initialValueShapes, expectedShape)) {
8102 return emitOpError(
"incompatible input/initial value shapes");
8106 Type eltType = getDestType().getElementType();
8109 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
8116 RewritePatternSet &patterns, PatternBenefit benefit) {
8118 .
add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
8119 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
8120 StridedSliceConstantMaskFolder, TransposeFolder>(
8125 CombiningKind kind, Value v1, Value acc,
8126 arith::FastMathFlagsAttr fastmath,
8133 case CombiningKind::ADD:
8135 result =
b.createOrFold<arith::AddIOp>(loc, v1, acc);
8136 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8137 result =
b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
8139 llvm_unreachable(
"invalid value types for ADD reduction");
8141 case CombiningKind::AND:
8143 result =
b.createOrFold<arith::AndIOp>(loc, v1, acc);
8145 case CombiningKind::MAXNUMF:
8146 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8147 "expected float values");
8148 result =
b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
8150 case CombiningKind::MAXIMUMF:
8151 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8152 "expected float values");
8153 result =
b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
8155 case CombiningKind::MINNUMF:
8156 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8157 "expected float values");
8158 result =
b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
8160 case CombiningKind::MINIMUMF:
8161 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8162 "expected float values");
8163 result =
b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
8165 case CombiningKind::MAXSI:
8167 result =
b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
8169 case CombiningKind::MINSI:
8171 result =
b.createOrFold<arith::MinSIOp>(loc, v1, acc);
8173 case CombiningKind::MAXUI:
8175 result =
b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
8177 case CombiningKind::MINUI:
8179 result =
b.createOrFold<arith::MinUIOp>(loc, v1, acc);
8181 case CombiningKind::MUL:
8183 result =
b.createOrFold<arith::MulIOp>(loc, v1, acc);
8184 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8185 result =
b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
8187 llvm_unreachable(
"invalid value types for MUL reduction");
8189 case CombiningKind::OR:
8191 result =
b.createOrFold<arith::OrIOp>(loc, v1, acc);
8193 case CombiningKind::XOR:
8195 result =
b.createOrFold<arith::XOrIOp>(loc, v1, acc);
8199 assert(
result &&
"unknown CombiningKind");
8207void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
8209 auto resultType = cast<VectorType>(
getType());
8210 if (resultType.isScalable()) {
8214 APInt zero(bitwidth, 0);
8215 APInt high(bitwidth, resultType.getDimSize(0) - 1);
8216 ConstantIntRanges
result = {zero, high, zero, high};
8217 setResultRanges(getResult(),
result);
8247struct StepCompareFolder :
public OpRewritePattern<StepOp> {
8250 LogicalResult matchAndRewrite(StepOp stepOp,
8251 PatternRewriter &rewriter)
const override {
8252 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
8254 for (OpOperand &use : stepOp.getResult().getUses()) {
8255 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8260 const unsigned stepOperandNumber = use.getOperandNumber();
8261 if (stepOperandNumber != 0)
8265 unsigned constOperandNumber = 1;
8266 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8267 std::optional<int64_t> maybeConstValue =
8269 if (!maybeConstValue.has_value())
8272 int64_t constValue = maybeConstValue.value();
8273 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8275 auto maybeSplat = [&]() -> std::optional<bool> {
8277 if ((pred == arith::CmpIPredicate::ult ||
8278 pred == arith::CmpIPredicate::uge) &&
8279 stepSize <= constValue)
8280 return pred == arith::CmpIPredicate::ult;
8283 if ((pred == arith::CmpIPredicate::ule ||
8284 pred == arith::CmpIPredicate::ugt) &&
8285 stepSize - 1 <= constValue) {
8286 return pred == arith::CmpIPredicate::ule;
8290 if ((pred == arith::CmpIPredicate::eq ||
8291 pred == arith::CmpIPredicate::ne) &&
8292 stepSize <= constValue)
8293 return pred == arith::CmpIPredicate::ne;
8295 return std::nullopt;
8298 if (!maybeSplat.has_value())
8303 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8308 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8320void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8321 MLIRContext *context) {
8322 results.
add<StepCompareFolder>(context);
8332 Operation *maskableOp) {
8333 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
8345 Operation *maskableOp, Value mask,
8350 return MaskOp::create(builder, maskableOp->
getLoc(),
8353 return MaskOp::create(builder, maskableOp->
getLoc(),
8366 Value newValue, Value passthru) {
8370 return arith::SelectOp::create(builder, newValue.
getLoc(), newValue.
getType(),
8371 mask, newValue, passthru);
8382struct InterleaveDeinterleaveFolder :
public OpRewritePattern<InterleaveOp> {
8385 LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
8386 PatternRewriter &rewriter)
const override {
8387 auto lhsDefOp = interleaveOp.getLhs().getDefiningOp<DeinterleaveOp>();
8388 auto rhsDefOp = interleaveOp.getRhs().getDefiningOp<DeinterleaveOp>();
8389 if (!lhsDefOp || !rhsDefOp || lhsDefOp != rhsDefOp)
8391 for (
auto [idx, operand] : llvm::enumerate(interleaveOp.getOperands())) {
8392 if (cast<OpResult>(operand).getResultNumber() != idx)
8395 rewriter.
replaceOp(interleaveOp, lhsDefOp.getSource());
8401void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
8402 MLIRContext *context) {
8403 results.
add<InterleaveDeinterleaveFolder>(context);
8406std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
8407 return llvm::to_vector<4>(getResultVectorType().
getShape());
8414std::optional<SmallVector<int64_t, 4>> DeinterleaveOp::getShapeForUnroll() {
8415 return llvm::to_vector<4>(getResultVectorType().
getShape());
8422#define GET_ATTRDEF_CLASSES
8423#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8425#define GET_OP_CLASSES
8426#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static OpFoldResult foldShuffleIdentityMask(ShuffleOp op)
Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static OpFoldResult foldShufflePoisonInputs(MLIRContext *context, Attribute v1Attr, Attribute v2Attr)
Fold shuffle poison, poison -> poison.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr, Attribute v2Attr)
Fold a shuffle of constant 1-D inputs by evaluating the mask.
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op)
If a shuffle operand is poison, replace all mask indices that reference it with kPoisonIndex.
static LogicalResult foldSize1TransferPermutationMap(TransferOp op)
When the vector type is vector<1xT>, the permutation map is irrelevant: the single vector lane always...
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)