#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Repeated.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
    // Scan the constant mask bits, counting up while every bit seen so far
    // is set and down while every bit seen so far is clear.
    for (bool b : denseElts.getValues<bool>())
      if (b && val >= 0)
        val = 1;
      else if (!b && val <= 0)
        val = -1;
  auto shape = m.getType().getShape();

  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
    if (maskIdx < dimSize)

  auto maskOperands = m.getOperands();
  for (Value operand : maskOperands) {
    if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
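// For instance, a mask like the following is detected as "all true", since
// every mask dimension size equals the corresponding vector dimension
// (illustrative example, standard vector dialect syntax):
//
//   %m = vector.constant_mask [4, 3] : vector<4x3xi1>
//
// whereas a create_mask whose operand is a known non-positive constant yields
// an "all false" mask.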
  vector::YieldOp::create(builder, loc);
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:

  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:

  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
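// As a concrete illustration, the float-only kinds above gate reductions such
// as (standard vector dialect syntax):
//
//   %0 = vector.reduction <maximumf>, %v : vector<8xf32> into f32
//
// while kinds such as AND, XOR, or MINUI are only valid on signless-integer
// or index element types.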
                                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;

                                                    VectorType vectorType) {
  if (shapedType.getRank() == 0 &&

      shapedType.getRank(),

      shapedType.getContext());
                                       vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();

  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)

                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||

                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  if (transferA.getVectorType() != transferB.getVectorType())

  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];

    if (i < rankOffset) {
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)

      if (testDynamicValueUsingBounds) {
        FailureOr<uint64_t> delta =
        if (succeeded(delta) && *delta != 0)

        FailureOr<bool> testEqual =
        if (succeeded(testEqual) && !testEqual.value())

      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)

      if (testDynamicValueUsingBounds) {
        FailureOr<int64_t> delta =
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)

        FailureOr<int64_t> computeDelta =
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)

                                            VectorTransferOpInterface transferB,
                                            bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())

                                   testDynamicValueUsingBounds);
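// As a rough example of what this analysis proves disjoint: two transfers on
// the same memref whose constant indices along a vector dimension differ by
// at least the vector size cannot overlap (illustrative IR):
//
//   %r = vector.transfer_read %mem[%c0], %pad : memref<16xf32>, vector<4xf32>
//   vector.transfer_write %val, %mem[%c8] : vector<4xf32>, memref<16xf32>
//
// Here |0 - 8| >= 4, so the read and the write access disjoint slices.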
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    if (posInDim < dimSize + offsetInDim)

    posInDim = offsetInDim;

  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();

      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
  llvm::transform(foldResults, std::back_inserter(values),
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                          builder, loc, cast<IntegerAttr>(attr).getInt())
                    return cast<Value>(foldResult);

  if (lhs.getDefiningOp<vector::VectorScaleOp>())

  if (rhs.getDefiningOp<vector::VectorScaleOp>())

  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto intType = dyn_cast<IntegerType>(expectedType)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());

  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    auto intType = dyn_cast<IntegerType>(expectedType);

    APFloat floatVal = floatAttr.getValue();
    APInt intVal = floatVal.bitcastToAPInt();
    return IntegerAttr::get(expectedType, intVal);
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

void VectorDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();

  return arith::ConstantOp::materialize(builder, value, type, loc);
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        CombiningKind kind) {
  for (const auto &en : llvm::enumerate(reductionMask))
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  if (getReductionDims().empty())

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
LogicalResult MultiDimReductionOp::verify() {
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);

      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);

  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();

    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
           << " is incompatible with source type " << getSourceVectorType();
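// For reference, the inferred return type drops exactly the reduced
// dimensions; e.g. reducing dim 1 of a 4x8 source (illustrative IR):
//
//   %1 = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x8xf32> to vector<4xf32>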
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
struct ElideUnitDimsInMultiDimReduction
  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)

    OpBuilder::InsertionGuard guard(rewriter);

    if (reductionOp.isMasked()) {
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();

      rootOp = reductionOp;

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();

    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());

        mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());

                                   cast, nullptr, mask);

void MultiDimReductionOp::getCanonicalizationPatterns(
  results.add<ElideUnitDimsInMultiDimReduction>(context);
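// The pattern above rewrites reductions whose reduced dimensions are all
// unit-sized; roughly (illustrative IR):
//
//   %r = vector.multi_reduction <add>, %v, %acc [0]
//          : vector<1x8xf32> to vector<8xf32>
//
// becomes a vector.shape_cast of %v to vector<8xf32>, combined with %acc.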
                         arith::FastMathFlags fastMathFlags) {

                         arith::FastMathFlags fastMathFlags) {
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,

LogicalResult ReductionOp::verify() {
  int64_t rank = getSourceVectorType().getRank();
    return emitOpError("unsupported reduction rank: ") << rank;

  Type eltType = getDest().getType();
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())

Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::OR, vector);
  case arith::AtomicRMWKind::minnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINNUMF, vector);
  case arith::AtomicRMWKind::maxnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXNUMF, vector);
  case arith::AtomicRMWKind::xori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::XOR, vector);
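// This helper mirrors arith::AtomicRMWKind onto vector::CombiningKind so that
// vectorization can materialize the matching reduction; e.g. the addf case
// builds (illustrative IR):
//
//   %0 = vector.reduction <add>, %v : vector<4xf32> into f32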
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());

    if (maskableOp.isMasked()) {
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();

      rootOp = reductionOp;

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)

    Location loc = reductionOp.getLoc();
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
                                 reductionOp.getFastmathAttr(), mask);

  results.add<ElideSingleElementReduction>(context);
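// Single-element reductions are trivially the element itself (combined with
// the accumulator when present); roughly (illustrative IR):
//
//   %a = vector.reduction <add>, %v : vector<1xf32> into f32
//
// folds to a vector.extract of the sole element of %v.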
                      getIndexingMapsAttrName(result.name),
                      getIteratorTypesAttrName(result.name),
            return IteratorTypeAttr::get(builder.getContext(), t);
        ContractionOp::getDefaultKind());

                          ArrayAttr iteratorTypes, CombiningKind kind,
                          arith::FastMathFlags fastMathFlags) {
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
                      CombiningKindAttr::get(builder.getContext(), kind));
  if (fastMathFlags != arith::FastMathFlags::none)
        getFastmathAttrName(result.name),
        arith::FastMathFlagsAttr::get(builder.getContext(), fastMathFlags));
  DictionaryAttr dictAttr;

  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));

  result.attributes.set(getIteratorTypesAttrName(result.name),

  if (!result.attributes.get(getKindAttrName(result.name))) {
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));

  if (masksInfo.empty())

  if (masksInfo.size() != 2)
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);

  std::array<VectorType, 2> maskTypes = {
  auto attrNames = getTraitAttrNames();
  traitAttrsSet.insert_range(attrNames);

  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();

          llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
      if (attr.getName() == getFastmathAttrName() &&
          llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
              arith::FastMathFlags::none)

      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
                        const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {

  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);

                               llvm::make_second_range(batchDimMap));

  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)

    expectedResultDims.push_back(lhsType.getDimSize(i));

  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)

    expectedResultDims.push_back(rhsType.getDimSize(i));

  if (expectedResultDims.empty()) {
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");

    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");

         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);

    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");

    auto expectedShape =
          return cast<AffineConstantExpr>(e).getValue();
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");

  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
             << index << " to have no symbols";

    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;

    if (map.getNumDims() != numIterators)
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
             << index << " to be a projected permutation of its inputs";

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

    return emitOpError("invalid batch dimension map");

                                contractingDimMap, batchDimMap)))

  if (!getKindAttr()) {
    return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
                       "'vector.kind<add>')");

  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
    return emitOpError("unsupported contraction type");

  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();

  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();

  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
        lhsType.getScalableDims()[dimIdx];

  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
        rhsType.getScalableDims()[dimIdx];

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), 1),
                         maskShapeScalableDims);

          getIteratorTypesAttrName(), getKindAttrName(),
          getFastmathAttrName()};
static std::vector<std::pair<int64_t, int64_t>>
           IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)

    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);

void ContractionOp::getIterationBounds(
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());

  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);

    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);

void ContractionOp::getIterationIndexMap(
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  getIterationBounds(shape);
template <typename AddOpType>
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
        return vector::ContractionOp();

      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;

      return vector::ContractionOp();

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
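// Net effect of this pattern (for AddOpType an addition like arith.addi or
// arith.addf):
//
//   add(contract(a, b, zero_acc), other) --> contract(a, b, other)
//
// i.e. an addition applied to a zero-accumulator contraction folds into the
// contraction's accumulator operand.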
  setResultRanges(getResult(), argRanges.front());

  auto vectorTy = cast<VectorType>(source.getType());

  build(builder, result, source, dynamicPos,

ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());

    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
          "expected a scalar instead of a 0-d vector as the result type");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");

  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
template <typename IntType>
  return llvm::map_to_vector<4>(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });

  if (!extractOp.getSource().getDefiningOp<ExtractOp>())

  if (extractOp.hasDynamicPosition())

  ExtractOp currentOp = extractOp;
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
    if (currentOp.hasDynamicPosition())

    globalPosition.append(extrPos.rbegin(), extrPos.rend());

  extractOp.setOperand(0, currentOp.getSource());

  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
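// Example of the chained-extract fold above (illustrative IR):
//
//   %1 = vector.extract %src[0] : vector<8x16xf32> from vector<4x8x16xf32>
//   %2 = vector.extract %1[2] : vector<16xf32> from vector<8x16xf32>
//
// is rewritten in place so that %2 becomes:
//
//   %2 = vector.extract %src[0, 2] : vector<16xf32> from vector<4x8x16xf32>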
class ExtractFromInsertTransposeChainState {
  ExtractFromInsertTransposeChainState(ExtractOp e);

  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());

  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)

    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));

  void updateStateForNextIteration(Value v) {

  LogicalResult handleTransposeOp();

  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
                         extractOp.getStaticPosition().end());
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (extractOp.hasDynamicPosition())

  if (!nextTransposeOp)

      nextTransposeOp.getPermutation(), extractOp.getContext()));

ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))

  res = nextInsertOp.getValueToStore();

ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();

  res = nextInsertOp.getValueToStore();

Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
  if (extractOp.hasDynamicPosition())

  bool nothingToFold = (source == extractOp.getSource());
  if (nothingToFold || !canFold())

  OpBuilder b(extractOp.getContext());
  extractOp.setStaticPosition(
  extractOp.getSourceMutable().assign(source);
  return extractOp.getResult();

Value ExtractFromInsertTransposeChainState::fold() {
  if (extractOp.hasDynamicPosition())

  Value valueToExtractFrom = extractOp.getSource();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);

    if (succeeded(handleInsertOpWithMatchingPos(result)))

    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();

    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);

  return tryToFoldExtractOpInPlace(valueToExtractFrom);
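// The chain walk above handles cases such as (illustrative IR):
//
//   %i = vector.insert %val, %dst [0, 1]
//          : vector<16xf32> into vector<4x8x16xf32>
//   %e = vector.extract %i[0, 1] : vector<16xf32> from vector<4x8x16xf32>
//
// where %e folds directly to %val, and keeps traversing through transposes
// and non-intersecting inserts otherwise.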
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;

  if (isa<BroadcastOp>(op))

  auto shapeCast = dyn_cast<ShapeCastOp>(op);

  VectorType srcType = shapeCast.getSourceVectorType();
  uint64_t srcRank = srcType.getRank();
  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;

  Operation *defOp = extractOp.getSource().getDefiningOp();

  if (extractOp.getType() == input.getType())

  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;

  if (extractRank > inputRank)

  assert(inputType && "input must be a vector type because of previous checks");
      extractType.getShape() != inputShape.take_back(extractRank))

  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;

  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];

  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())

  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();

  if (shuffleOp.getResultVectorType().getRank() != 1)

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});

    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});

  return extractOp.getResult();
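// Example of the shuffle fold above (illustrative IR):
//
//   %s = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
//   %e = vector.extract %s[1] : f32 from vector<4xf32>
//
// mask[1] == 8 points into %b (the first input holds 8 elements), so %e is
// rewritten to:
//
//   %e = vector.extract %b[0] : f32 from vector<8xf32>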
  if (extractOp.hasDynamicPosition())

  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();

  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();

      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()

  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())

  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))

  std::reverse(extractedPos.begin(), extractedPos.end());

  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);

      shapeCastOp.getSourceVectorType().getRank() - destinationRank;

  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);

  std::reverse(newStrides.begin(), newStrides.end());

  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())

  auto extractStridedSliceOp =
      extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)

  if (extractStridedSliceOp.hasNonUnitStrides())

  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))

    sliceOffsets.pop_back();

  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();

  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())

  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())

      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()

  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();

    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;

    bool disjoint = false;

    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
          (dim < insertRankDiff)
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);

      int64_t offset = extractOffsets[dim];

      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);

          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +

      extractOp.getSourceMutable().assign(insertOp.getValueToStore());

      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();

    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  if (extractOp.hasDynamicPosition())

  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)

  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())

  int64_t rank = vecType.getRank();

  if (extractOp.getType() != vecType.getElementType())
         "unexpected number of indices");

  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);

  return fromElementsOp.getElements()[flatIndex];
template <typename OpType, typename AdaptorType>
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();

  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();

  if (!dynamicPosition.size())

  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (ShapedType::isStatic(staticPosition[i]))

    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();

        staticPosition[i] = attr.getInt();

    operands.push_back(position);

    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);

    return op.getResult();
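// This template canonicalizes constant dynamic positions into the static
// position list; e.g. with %c1 = arith.constant 1 : index (illustrative IR):
//
//   %e = vector.extract %v[%c1] : f32 from vector<4xf32>
//
// becomes
//
//   %e = vector.extract %v[1] : f32 from vector<4xf32>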
  if (!is_contained(staticPos, poisonVal))

  return ub::PoisonAttr::get(context);
  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);

  if (denseAttr.isSplat()) {
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))

  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())

  if (extractOp.hasDynamicPosition()) {

  copy(extractOp.getStaticPosition(), completePositions.begin());

  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());

    newAttr = *denseValuesBegin;
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())

  SmallVector<Value> operands = {getSource()};
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))

  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())

  return inplaceFolded;
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
        BroadcastableToResult::Success)
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());

    if (!extractedMaskType)

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
        containsUnknownDims = true;

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        containsUnknownDims = true;

    } else if (!containsUnknownDims) {
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto constantMaskOp =
        extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
    if (!constantMaskOp)

    Type resultType = extractOp.getResult().getType();
    auto extractedMaskType = dyn_cast<VectorType>(resultType);

    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();

    VectorType maskType = constantMaskOp.getVectorType();

    for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      if (pos == ShapedType::kDynamic) {
        if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))

      if (pos >= maskDimSizes[dimIdx]) {
        if (extractedMaskType) {

    if (extractedMaskType) {
          extractOp, extractedMaskType,
          maskDimSizes.drop_front(extractOpPos.size()));
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());

  if (sourceType.getNumElements() != targetType.getNumElements())

                                                   castOp.getSource());
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  if (extractOp.hasDynamicPosition())

  auto resultType = dyn_cast<VectorType>(extractOp.getType());

  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)

  VectorType inputType = fromElementsOp.getType();

  if (resultType.isScalable() || inputType.isScalable())

  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);

  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);

      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
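// Example of the from_elements fold above (illustrative IR):
//
//   %v = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
//   %e = vector.extract %v[1] : vector<2xf32> from vector<2x2xf32>
//
// folds to
//
//   %e = vector.from_elements %c, %d : vector<2xf32>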
struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = extractOp.getSourceVectorType();
    VectorType outType = dyn_cast<VectorType>(extractOp.getType());

    if (sourceType.getNumElements() != outType.getNumElements())
          extractOp, "extract to vector with fewer elements");

    if (llvm::any_of(extractOp.getMixedPosition(),
                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
                                         "leaving for extract poison folder");

                                                     extractOp.getSource());
struct FoldExtractFromInsertUnitDim final
    : OpRewritePattern<vector::ExtractOp> {
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (extractOp.hasDynamicPosition())

    auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp || insertOp.hasDynamicPosition())

    ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
    ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();

    if (extractPos.size() >= insertPos.size() ||
        extractPos != insertPos.take_front(extractPos.size()))

    auto srcVecType = extractOp.getSourceVectorType();
    for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
      if (srcVecType.getDimSize(i) != 1)

    Value inserted = insertOp.getValueToStore();
    Type extractedType = extractOp.getResult().getType();
    if (isa<VectorType>(inserted.getType())) {
          extractOp, extractOp.getResult().getType(),
          insertOp.getValueToStore());
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
              ExtractOpFromConstantMask, ExtractToShapeCast,
              FoldExtractFromInsertUnitDim>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);

  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {

  if (operands.empty())

  return llvm::all_of(operands, [&](Value operand) {
    return currentDef == defOp;

  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)

  llvm::append_range(results, fromElementsOp.getElements());

  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();

  if (isa<VectorType>(bcastOp.getSource().getType()))

  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());

  Value scalar = bcastOp.getSource();
  results.assign(resultVecType.getNumElements(), scalar);
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
  if (auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(shapeCast.getSource());

ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                               ToElementsOp::Adaptor adaptor,
                               SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vecType = cast<VectorType>(adaptor.getSource().getType());
  Type elType = vecType.getElementType();
  inferredReturnTypes.append(vecType.getNumElements(), elType);
    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();

    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());

    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());

    int64_t dstRank = dstShape.size();
    int64_t srcRank = srcShape.size();

    auto srcElems = vector::ToElementsOp::create(
        rewriter, toElementsOp.getLoc(), bcastOp.getSource());

    int64_t dstCount = llvm::product_of(dstShape);

    replacements.reserve(dstCount);

    for (int64_t lin = 0; lin < dstCount; ++lin) {
      for (int64_t k = 0; k < srcRank; ++k)
        srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];

      replacements.push_back(srcElems.getResult(srcLin));

    rewriter.replaceOp(toElementsOp, replacements);

void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ToElementsOfBroadcast>(context);
  OperandRange fromElemsOperands = fromElementsOp.getElements();
  if (fromElemsOperands.empty())

  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();

  Value toElementsInput = toElementsOp.getSource();
  if (fromElementsOp.getType() == toElementsInput.getType() &&
      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
    return toElementsInput;

  if (llvm::any_of(elements, [](Attribute attr) {

  auto destVecType = fromElementsOp.getDest().getType();
  auto destEltType = destVecType.getElementType();
  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))

  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {

  if (!llvm::all_equal(fromElementsOp.getElements()))

      fromElementsOp, fromElementsOp.getType(),
      fromElementsOp.getElements().front());
  LogicalResult matchAndRewrite(FromElementsOp fromElements,

    if (fromElements.getType().getNumElements() == 1)

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {

      auto extractOp = element.getDefiningOp<vector::ExtractOp>();
                                           "element not from vector.extract");

      if (insertIndex == 0) {
        source = extractOp.getSource();
      } else if (extractOp.getSource() != source) {
                                           "element from different vector");

      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();

        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);

        if (numSuffixElms != numElms) {
              fromElements, "elements do not form a suffix of source");

        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);

      } else if (expectedPosition != position) {
            fromElements, "elements not in ascending order (static order)");

      increment(expectedPosition, source.getType().getShape());

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);

        fromElements, fromElements.getType(), extracted);
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {

void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
  setResultRanges(getResult(), argRanges.front());

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
static llvm::SetVector<int64_t>
  int64_t rankDiff = dstShape.size() - srcShape.size();

  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());

  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))

    checkShape.push_back(dstShape[i]);

  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "

  Location loc = value.getLoc();

  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());

  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;

      permutation[i] = nextSrcShapeDim++;

  llvm::append_range(broadcastShape, srcVectorType.getShape());

         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);

  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
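// Worked example for the broadcast + transpose lowering above: broadcasting
// vector<4x5xf32> to vector<4x2x5xf32> with broadcastedDims = {1} first
// broadcasts to vector<2x4x5xf32> (new dims leading), then transposes with
// permutation [1, 0, 2] to restore the requested dimension order.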
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&

  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)

  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    bool foundMismatchingDims = false;

    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
      getSourceType(), getResultVectorType(), &mismatchingDims);

    return emitOpError("source rank higher than destination rank");

           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";

    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();

  VectorType srcType = srcShapeCast.getSourceVectorType();
  VectorType destType = broadcastOp.getResultVectorType();
      srcShapeCast.getResultVectorType().getShape();

  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
  if (!llvm::equal(srcShape.take_back(numTrailingDims),
                   shapecastShape.take_back(numTrailingDims)))

  assert(all_of(srcShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         all_of(shapecastShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         "ill-formed shape_cast");

  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())

  if (!adaptor.getSource())

  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())

  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())

  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
                                   broadcastOp.getResultVectorType(),
                                   srcBroadcast.getSource());

struct BroadcastToShapeCast final
    : public OpRewritePattern<vector::BroadcastOp> {
  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
                                PatternRewriter &rewriter) const override {
    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
          broadcast, "source is a scalar, shape_cast doesn't support scalar");

    if (sourceType.getNumElements() != outType.getNumElements()) {
          broadcast, "broadcast to a greater number of elements");

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
  if (!v1Type)
    return emitOptionalError(loc, "expected vector type for first operand");
  auto v1Rank = v1Type.getRank();
  // Construct the resulting type: the leading dimension has the mask length,
  // all trailing dimensions match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin,
                             size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
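// Used below to spot pass-through masks. For example,
// isStepIndexArray(mask, /*begin=*/0, /*width=*/4) holds exactly for the mask
// [0, 1, 2, 3], and isStepIndexArray(mask, /*begin=*/4, /*width=*/2) for
// [4, 5].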
/// If exactly one of the shuffle inputs is poison, rewrite the mask in place
/// so that every position reading from the poison input becomes a poison
/// index. Returns the shuffle result on success, and a null Value otherwise.
static Value foldShuffleOfPoisonInput(ShuffleOp op, Attribute v1Attr,
                                      Attribute v2Attr) {
  auto v1Type = op.getV1VectorType();
  auto v2Type = op.getV2VectorType();
  auto mask = op.getMask();
  (void)v2Type;

  bool isV1Poison = isa_and_present<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa_and_present<ub::PoisonAttr>(v2Attr);
  // shuffle(poison, poison) folds to poison wholesale; handled below.
  if (isV1Poison == isV2Poison)
    return Value();

  // Turn the mask positions that read from the poison input into poison
  // indices.
  int64_t v1Size = op.getV1VectorType().getDimSize(0);
  bool changed = false;
  SmallVector<int64_t> newMask = llvm::to_vector(mask);
  for (int64_t &idx : newMask) {
    if (idx == ShuffleOp::kPoisonIndex)
      continue;
    if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
      idx = ShuffleOp::kPoisonIndex;
      changed = true;
    }
  }
  if (!changed)
    return Value();
  op.setMask(newMask);
  return op.getResult();
}

/// Constant-fold a 1-D shuffle of two dense (or poison) inputs by indexing
/// into the concatenated element lists.
static Attribute foldShuffleOfDenseInputs(ShuffleOp op, Attribute v1Attr,
                                          Attribute v2Attr) {
  MLIRContext *context = op.getContext();
  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(context);

  // Only 1-D shuffles are folded here, to avoid n-D DenseElementsAttr
  // manipulation.
  auto v1Type = op.getV1VectorType();
  if (v1Type.getRank() != 1)
    return {};

  // Poison inputs carry no elements; when an element of a poison input (or a
  // poison mask index) is selected, fall back to the first element of a
  // non-poison input.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    if (!v2DenseAttr)
      return {};
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    if (!v1DenseAttr)
      return {};
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  SmallVector<Attribute> results;
  ArrayRef<int64_t> mask = op.getMask();
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }
    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(op.getResultVectorType(), results);
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();

  assert(!v1Type.isScalable() && !getV2VectorType().isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // A 0-D shuffle produces a 1-D result, so it can never fold to one of its
  // operands; that case is handled by Canonicalize0DShuffleOp below.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffles that merely pass one input through, e.g.
  // shuffle %v1, %v2 [0, 1, 2, 3] : vector<4xi32>, vector<2xi32> -> %v1.
  if (isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
    return getV1();
  if (isStepIndexArray(getMask(), v1Type.getDimSize(0),
                       getV2VectorType().getDimSize(0)))
    return getV2();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (Value folded = foldShuffleOfPoisonInput(*this, v1Attr, v2Attr))
    return folded;
  if (!v1Attr || !v2Attr)
    return {};

  return foldShuffleOfDenseInputs(*this, v1Attr, v2Attr);
}
namespace {

/// Rewrite a 0-D shuffle with a single-entry mask (which returns a 1-D
/// vector) into a broadcast of the selected operand.
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
                                PatternRewriter &rewriter) const override {
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
      return failure();
    if (mask.size() != 1)
      return failure();
    VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
    if (mask[0] == 0)
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
                                                       shuffleOp.getV1());
    else
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
                                                       shuffleOp.getV2());
    return success();
  }
};
} // namespace
/// Return the scalar value splatted by `value` if it is produced by a
/// vector.broadcast of a scalar, and a null Value otherwise.
static Value getScalarSplatSource(Value value) {
  Operation *defOp = value.getDefiningOp();
  if (!defOp)
    return {};
  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
  if (!broadcast)
    return {};
  // A broadcast of a vector is not a scalar splat.
  if (isa<VectorType>(broadcast.getSourceType()))
    return {};
  return broadcast.getSource();
}
namespace {

/// Rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getV1());
    if (!splat || getScalarSplatSource(op.getV2()) != splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};
/// Rewrite a fixed-size interleave expressed via vector.shuffle into
/// vector.interleave.
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent a scalable interleave");

    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp types don't match an interleave");
    }

    ArrayRef<int64_t> shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
        return rewriter.notifyMatchFailure(op,
                                           "ShuffleOp mask not interleaving");
    }

    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
    return success();
  }
};
/// Replace a shuffle operand that no mask position reads from with a poison
/// value, so that later folds can take advantage of it.
class FoldUnusedShuffleOperand final : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    // An all-poison mask is handled by the folder.
    if (llvm::all_of(op.getMask(), [](int64_t mask) {
          return mask == ShuffleOp::kPoisonIndex;
        }))
      return failure();

    auto replaceOperandWithPoison = [&](OpOperand &operand) -> LogicalResult {
      // Nothing to do if the operand is already poison.
      if (operand.get().getDefiningOp<ub::PoisonOp>())
        return failure();
      Value poison = ub::PoisonOp::create(rewriter, op.getLoc(),
                                          operand.get().getType());
      rewriter.modifyOpInPlace(op, [&]() { operand.set(poison); });
      return success();
    };

    int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
                                ? op.getV1VectorType().getDimSize(0)
                                : 1;
    bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
      return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
    });
    if (!isV1Used && succeeded(replaceOperandWithPoison(op.getV1Mutable())))
      return success();

    bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
      return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
    });
    if (!isV2Used && succeeded(replaceOperandWithPoison(op.getV2Mutable())))
      return success();

    return failure();
  }
};
} // namespace

void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
              FoldUnusedShuffleOperand>(context);
}
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest) {
  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, OpFoldResult position) {
  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<int64_t> position) {
  SmallVector<OpFoldResult> posVals;
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult InsertOp::verify() {
  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
    if (srcTy.getRank() == 0)
      return emitOpError(
          "expected a scalar instead of a 0-d vector as the source operand");

  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");

  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");

  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding dest vector dimension";
      }
    }
  }
  return success();
}

/// Calculate the linearized position of the continuous chunk of elements to
/// insert, by treating the insert position as if it were completed with
/// trailing zeros up to the destination rank.
static int64_t calculateInsertPosition(VectorType destTy,
                                       ArrayRef<int64_t> positions) {
  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
  assert(positions.size() <= completePositions.size() &&
         "positions size must be less than or equal to destTy rank");
  copy(positions, completePositions.begin());
  return linearize(completePositions, computeStrides(destTy.getShape()));
}
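// Worked example: for destTy = vector<2x3xf32> the strides are [3, 1], so the
// positions [1] complete to [1, 0] and the chunk begins at linear index 3.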
namespace {

/// If insertOp only inserts unit dimensions, it can be transformed into a
/// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType =
        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
    return success();
  }
};
/// Rewrite an insert of a splat-like value into a splat-like destination as a
/// broadcast of that value.
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getValueToStore());
    if (!splat || getScalarSplatSource(op.getDest()) != splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};
/// Optimize a chain of insertions that:
///   1. only insert values at static positions,
///   2. completely initialize all elements of the resulting vector, and
///   3. whose intermediate inserts have a single use.
/// Such a chain can be replaced by a single vector.from_elements op; an
/// example follows the class definition below.
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    VectorType destTy = op.getDestVectorType();
    if (destTy.isScalable())
      return failure();
    // If this insert feeds another insert's destination, let the pattern
    // match the end of the chain instead.
    for (Operation *user : op.getResult().getUsers())
      if (auto insertOp = dyn_cast<InsertOp>(user))
        if (insertOp.getDest() == op.getResult())
          return failure();

    InsertOp currentOp = op;
    SmallVector<InsertOp> chainInsertOps;
    while (currentOp) {
      // Dynamic positions cannot be analyzed statically.
      if (currentOp.hasDynamicPosition())
        return failure();
      chainInsertOps.push_back(currentOp);
      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
      // Intermediate inserts in the chain must have a single use.
      if (currentOp && !currentOp->hasOneUse())
        return failure();
    }

    int64_t vectorSize = destTy.getNumElements();
    int64_t initializedCount = 0;
    SmallVector<bool> initializedDestIdxs(vectorSize, false);
    SmallVector<int64_t> pendingInsertPos;
    SmallVector<int64_t> pendingInsertSize;
    SmallVector<Value> pendingInsertValues;

    for (auto insertOp : chainInsertOps) {
      // The insert op folder will deal with poison positions.
      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
        return failure();

      // Calculate the linearized position of the inserted chunk.
      int64_t insertBeginPosition =
          calculateInsertPosition(destTy, insertOp.getStaticPosition());

      int64_t insertSize = 1;
      if (auto srcVectorType =
              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
        insertSize = srcVectorType.getNumElements();

      assert(insertBeginPosition + insertSize <= vectorSize &&
             "insert would overflow the vector");

      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
                                           insertBeginPosition + insertSize)) {
        if (initializedDestIdxs[index])
          continue;
        initializedDestIdxs[index] = true;
        initializedCount++;
      }

      // Defer op creation until we know the whole chain folds.
      pendingInsertPos.push_back(insertBeginPosition);
      pendingInsertSize.push_back(insertSize);
      pendingInsertValues.push_back(insertOp.getValueToStore());

      if (initializedCount == vectorSize)
        break;
    }

    // The chain must fully initialize the destination vector.
    if (initializedCount != vectorSize)
      return failure();

    SmallVector<Value> elements(vectorSize);
    // Apply the inserts oldest-first so that newer inserts overwrite.
    for (auto [insertBeginPosition, insertSize, valueToStore] :
         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
                                 pendingInsertValues))) {
      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
      // Scalar insert.
      if (!srcVectorType) {
        elements[insertBeginPosition] = valueToStore;
        continue;
      }
      // Vector insert: expand the stored vector into its elements.
      Repeated<Type> elementToInsertTypes(insertSize,
                                          srcVectorType.getElementType());
      auto elementsToInsert = vector::ToElementsOp::create(
          rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
        elements[insertBeginPosition + linearIdx] =
            elementsToInsert.getResult(linearIdx);
      }
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
    return success();
  }
};
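// Illustrative IR for the InsertChainFullyInitialized rewrite above:
//
//   %0 = vector.insert %a, %v [0] : f32 into vector<2xf32>
//   %1 = vector.insert %b, %0 [1] : f32 into vector<2xf32>
//
// becomes
//
//   %1 = vector.from_elements %a, %b : vector<2xf32>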
} // namespace

/// Fold an insert of a constant value into a constant destination, producing
/// a new dense constant. Constants larger than
/// `maxVectorSizeFoldThreshold` elements are only folded when the insert has
/// a single use.
static Attribute
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
                                  Attribute dstAttr,
                                  int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst || !srcAttr)
    return {};

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  // Poison positions are handled by the poison folder.
  if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
    return {};

  int64_t insertBeginPosition =
      calculateInsertPosition(destTy, insertOp.getStaticPosition());
  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  // Convert attributes to the expected element type if needed.
  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(convertIntegerAttr(value, destEltType));
  } else {
    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
  }

  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);
  return DenseElementsAttr::get(destTy, allValues);
}
/// If this insert's destination is produced by another insert at the very
/// same position, the earlier insert is fully overwritten and can be
/// bypassed.
static Value foldInsertUseChain(InsertOp insertOp) {
  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
  if (!destInsert)
    return {};
  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
    return {};
  insertOp.setOperand(1, destInsert.getDest());
  return insertOp.getResult();
}
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertChainFullyInitialized>(context);
}
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>"
  // (i.e. an empty position) into %v.
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  if (Value res = foldInsertUseChain(*this))
    return res;
  // Fold `arith.constant` indices into the `vector.insert` operation.
  SmallVector<Value> operands = {getValueToStore(), getDest()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);

  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }

  return inplaceFolded;
}
//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}
template <typename OpType>
static LogicalResult
isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr,
                                   ArrayRef<int64_t> shape,
                                   StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}
// Returns true if all integers in `arrayAttr` are in the half-open interval
// [min, max). If `halfOpen` is false, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ")
             << attrName << " to be confined to [" << min << ", " << upper
             << ")";
  }
  return success();
}
// Returns true if all integers in `arrayAttr` are in the half-open interval
// [min, max) formed, dimension by dimension, by the operand's shape. If
// `halfOpen` is false, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
// Returns true if, for each index i, the sum arrayAttr1[i] + arrayAttr2[i] is
// confined to the half-open interval [min, max) formed by the operand's
// shape. If `halfOpen` is false, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
namespace {

/// Rewrite an insert_strided_slice of a splat-like value into a splat-like
/// destination as the destination itself.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto dst = insertStridedSliceOp.getDest();
    auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
    if (!splat || getScalarSplatSource(dst) != splat)
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, dst);
    return success();
  }
};
/// Rewrite an insert_strided_slice of an extract_strided_slice back into its
/// original destination when offsets and strides match, which makes the
/// insert a no-op.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getValueToStore()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check that the slices have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
/// Rewrite an insert_strided_slice of constant values into a constant
/// destination as a new constant.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the destination constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if the destination is not defined by a compatible constant.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Do not fold on large vectors unless there is a single use.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    TypedValue<VectorType> sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Support non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Walk the positions of the slice within the destination, copying the
    // slice elements over the destination elements one by one.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};

} // namespace
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}
OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getValueToStore();
  return {};
}
//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

/// Build an op without a mask, using the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}

void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand.
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what is currently supported for scalable
      // vectors; it could be relaxed given a use case.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!getKindAttr()) {
    return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
                       "'vector.kind<add>')");
  }
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}
//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//

// Inference works as follows:
//   1. Add 'sizes' from the prefix of dims covered by 'offsets'.
//   2. Add sizes from 'vectorType' for the remaining dims.
// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  if (failed(
          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
      failed(
          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(
          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }

  return success();
}
// When the source of the ExtractStridedSliceOp comes from a chain of
// InsertStridedSliceOps, try to use the source of one of those inserts if we
// can detect that the extracted chunk is a subset of an inserted chunk.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract an integer out of an ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the rank of the extract is greater than the rank of the insert, we
    // are likely extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        // The intervals overlap, but the extract may extend past the insert:
        // that is only a partial overlap.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
      } else {
        disjoint = true;
        break;
      }
    }
    // The extracted chunk is fully contained in the inserted chunk: rewrite
    // the extract to read from the inserted value directly.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getValueToStore());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunks are disjoint, keep looking up the insert chain;
    // otherwise the partial overlap prevents folding.
    if (disjoint)
      insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
    else
      return failure();
  }
  return failure();
}
/// Constant-fold an extract_strided_slice of a dense constant by gathering
/// the slice's elements into a smaller dense constant.
static OpFoldResult
foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
                                        Attribute foldInput) {
  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
  if (!dense)
    return {};

  // TODO: Handle non-unit strides when they become available.
  if (op.hasNonUnitStrides())
    return {};

  VectorType sourceVecTy = op.getSourceVectorType();
  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

  VectorType sliceVecTy = op.getType();
  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
  int64_t rank = sliceVecTy.getRank();

  // Expand offsets to the rank of the source vector.
  SmallVector<int64_t, 4> offsets(rank, 0);
  copy(getI64SubArray(op.getOffsets()), offsets.begin());

  // Calculate the slice elements by enumerating all slice positions and
  // linearizing them.
  const auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(sliceVecTy.getNumElements());
  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
  do {
    int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
  } while (succeeded(
      incSlicePosition(currSlicePosition, sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
             sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  return DenseElementsAttr::get(sliceVecTy, sliceValues);
}
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getSource();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();

  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());

  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
}

void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getOffsets(), results);
}
namespace {

// Rewrite an ExtractStridedSliceOp(ExtractStridedSliceOp) into a single
// ExtractStridedSliceOp with combined offsets and sizes.
class StridedSliceFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
                                PatternRewriter &rewriter) const override {
    auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
    if (!firstOp)
      return failure();

    // TODO: Support non-unit strides.
    if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
      return failure();

    SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
    SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
    SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
    SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());

    unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
    SmallVector<int64_t> combinedOffsets(newRank, 0);
    SmallVector<int64_t> combinedSizes(newRank);
    ArrayRef<int64_t> firstSourceShape =
        firstOp.getSourceVectorType().getShape();
    for (unsigned i = 0; i < newRank; ++i) {
      // Offsets along each dimension simply add up.
      int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
      int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
      combinedOffsets[i] = off1 + off2;
      // The second (inner) slice determines the size where it applies.
      if (i < secondSizes.size()) {
        combinedSizes[i] = secondSizes[i];
      } else if (i < firstSizes.size()) {
        combinedSizes[i] = firstSizes[i];
      } else {
        combinedSizes[i] = firstSourceShape[i];
      }
    }

    SmallVector<int64_t> combinedStrides(newRank, 1);
    rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
        secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
        combinedStrides);
    return success();
  }
};
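// For example (offsets add up, the second slice picks the sizes):
//
//   %0 = vector.extract_strided_slice %v
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xf32> to vector<4xf32>
//   %1 = vector.extract_strided_slice %0
//       {offsets = [1], sizes = [2], strides = [1]}
//       : vector<4xf32> to vector<2xf32>
//
// fold into a single op:
//
//   %1 = vector.extract_strided_slice %v
//       {offsets = [3], sizes = [2], strides = [1]}
//       : vector<8xf32> to vector<2xf32>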
// Rewrite an ExtractStridedSliceOp(CreateMaskOp) to CreateMaskOp.
class StridedSliceCreateMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    Location loc = extractStridedSliceOp.getLoc();
    // Return if the operand is not defined by a CreateMaskOp.
    auto createMaskOp =
        extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
    if (!createMaskOp)
      return failure();
    // Return if the extract has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather the (dynamic) mask dimension sizes.
    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the slice of the mask region.
    SmallVector<Value> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    // sliceOffsets.size() <= maskDimSizes.size(), so llvm::zip only iterates
    // over the leading dim sizes.
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      // The sliced mask region starts `sliceOffset` elements later, so the
      // new mask size along this dim is the old size minus the offset.
      IntegerAttr offsetAttr =
          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
      Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
      Value sliceMaskDimSize =
          arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add the unchanged trailing dimensions.
    llvm::append_range(
        sliceMaskDimSizes,
        llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
    // Replace the extract with a CreateMaskOp of the sliced mask region.
    rewriter.replaceOpWithNewOp<CreateMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
// Rewrite an ExtractStridedSliceOp(ConstantMaskOp) to ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if the extract has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes.
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the slice of the mask region.
    SmallVector<int64_t> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add the unchanged trailing dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any sliced dimension is zero, the whole mask is all-false.
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
    // Replace the extract with a ConstantMaskOp of the sliced mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
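// Worked example for the size computation above: slicing a constant-mask
// dimension of size 6 with offset 2 and size 3 keeps
// max(0, min(2 + 3, 6) - 2) = 3 set elements; with offset 5 and size 3 it
// keeps max(0, min(5 + 3, 6) - 5) = 1.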
// Rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the source of the broadcast are
    // the same as those of the destination of the extract. If so, we can just
    // use a broadcast, as the original dimensions are untouched.
    bool needsSlice = false;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != 1 &&
          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        needsSlice = true;
        break;
      }
    }
    Value source = broadcast.getSource();
    if (needsSlice) {
      SmallVector<int64_t> offsets =
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
      SmallVector<int64_t> sizes =
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
      for (unsigned i = 0; i < srcRank; i++) {
        if (srcVecType.getDimSize(i) == 1) {
          // The dim is broadcast from a unit dim; slicing it is a no-op, so
          // extract the full (unit) dimension.
          offsets[i] = 0;
          sizes[i] = 1;
        }
      }
      source = ExtractStridedSliceOp::create(
          rewriter, op->getLoc(), source, offsets, sizes,
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};
/// Rewrite extract_strided_slice(splat-like(v)) as broadcast(v).
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getSource());
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};
/// Rewrite simple cases of N-D extract_strided_slice, where the slice is
/// contiguous, into extract + shape_cast.
///
/// For example:
///     %1 = vector.extract_strided_slice %arg0 {
///            offsets = [0, 0, 0, 0, 0],
///            sizes = [1, 1, 1, 1, 8],
///            strides = [1, 1, 1, 1, 1]
///          } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
///   is rewritten into:
///     %0 = vector.extract %arg0[0, 0, 0, 0]
///            : vector<8xi8> from vector<8x1x1x2x8xi8>
///     %1 = vector.shape_cast %0
///            : vector<8xi8> to vector<1x1x1x1x8xi8>
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build. That is the
    // difference between the source rank and the number of full-size
    // dimensions on the suffix.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the innermost dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions; the
    // shape_cast op we create below would otherwise hit generic fallback
    // lowering patterns.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
                                              extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
                                                     extract);
    return success();
  }
};
} // namespace

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
              StridedSliceConstantMaskFolder, StridedSliceBroadcast,
              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
      context);
}
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

/// Builder that sets an empty mask and, if absent, a poison padding value
/// (variant with attributes).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr,
                           std::optional<Value> padding) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}

/// Builder that sets an empty mask and, if absent, a poison padding value
/// (variant with unwrapped map/inBounds).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds,
                           std::optional<Value> padding) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, *padding,
        permutationMapAttr, inBoundsAttr);
}

/// Builder that also sets the permutation map to the minor identity map.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           std::optional<ArrayRef<bool>> inBounds,
                           std::optional<Value> padding) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        vectorElementType.getElementTypeBitWidth() *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that the permutation map results match 'rankOffset' of the vector
    // type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        vectorType.getElementTypeBitWidth() * minorSize;
    if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that the permutation map results match the rank of the vector
    // type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
template <typename TransferOp>
static void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide the in_bounds attribute if all dims are out-of-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}

VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

  // 0-D masks are not supported at the moment.
  if (maskShape.empty())
    maskShape.push_back(1);

  SmallVector<bool> scalableDims =
      applyPermutationMap(invPermMap, vecType.getScalableDims());

  return VectorType::get(maskShape, i1Type, scalableDims);
}
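// For example, a transfer of vector<2x4xf32> through the transposing map
// (d0, d1) -> (d1, d0) uses a mask of type vector<4x2xi1>: the mask is shaped
// like the accessed memory region rather than like the result vector.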
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for a padding value.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // Instead of adding the mask type to the op signature, compute it from
    // the vector type and the permutation map (to keep the type small).
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type: the types must match exactly.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that the padding type and the vector element type match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

VectorType TransferReadOp::getVectorType() {
  return cast<VectorType>(getVector().getType());
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Indices of non-broadcast dims - used when analysing broadcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }

    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are in-bounds, then
  // all broadcast dims are in-bounds as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an ArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}
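// For example, a transfer_read of vector<4xf32> at constant index 0 from
// memref<8xf32> can never run out of bounds (0 + 4 <= 8), so the fold
// upgrades in_bounds = [false] to in_bounds = [true].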
template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();

  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();

  op.getMaskMutable().clear();
  return success();
}
/// Fold a read-after-write on a tensor: if the read reads back exactly what
/// the defining transfer_write stored, return the written vector directly.
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldTransferFullMask(*this)))
    return getResult();
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}
std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
/// Invert `map`, which must be a projected permutation, while keeping the
/// original number of dims: dims that do not appear in the results stay
/// unused in the inverse.
static AffineMap inverseWithUnusedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected a projected permutation map");
  SmallVector<AffineExpr> exprs(map.getNumDims(),
                                getAffineConstantExpr(0, map.getContext()));
  for (auto [idx, result] : llvm::enumerate(map.getResults())) {
    int64_t pos = cast<AffineDimExpr>(result).getPosition();
    exprs[pos] = getAffineDimExpr(idx, map.getContext());
  }
  return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs,
                        map.getContext());
}
namespace {
/// Fold a transfer_read of a tensor produced by a transfer_write of a
/// smaller-rank vector into a broadcast + transpose of the written vector,
/// avoiding the round-trip through the tensor.
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if we would need an alias analysis.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail if we would need a bounds analysis.
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    // TODO: If the written transfer chunk is a superset of the read chunk we
    // could do an extract_strided_slice instead.
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();
    // TODO: Support cases where a dim is explicitly written but implicitly
    // read (i.e. a unit dim that is rank-reduced).
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();

    // Compose the read map with the inverse of the write map to relate the
    // read vector's dims to the written vector's dims. Results that are
    // constant zero correspond to broadcast dims of the read.
    AffineMap readMap = readOp.getPermutationMap();
    AffineMap writeMap = defWrite.getPermutationMap();
    AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
    AffineMap composedMap = readMap.compose(invWriteMap);

    SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
    if (broadcastedDims.empty())
      return failure();
    int64_t numBroadcastedDims = broadcastedDims.size();
    auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
    invPerm.resize(composedMap.getNumResults());
    for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
      if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
        int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
        invPerm[effectiveDim] = idx;
      }
    }

    // The broadcast prepends the broadcast dims; a transpose then moves every
    // dim into its read position.
    VectorType readVecTy = readOp.getVectorType();
    auto broadcastedVecTy = VectorType::get(
        applyPermutation(readVecTy.getShape(), invPerm),
        readVecTy.getElementType(),
        applyPermutation(readVecTy.getScalableDims(), invPerm));

    Value vec = defWrite.getVector();
    Location loc = readOp.getLoc();
    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        readOp, vec, invertPermutationVector(invPerm));
    return success();
  }
};
} // namespace
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  return bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(), getResults());
}
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//

/// 1. Builder with type inference.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// 2. Builder with type inference that sets an empty mask (variant with
/// attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 3. Builder with type inference that sets an empty mask (variant without
/// attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder with type inference that sets an empty mask and sets the
///    permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

Value TransferWriteOp::getVector() { return getOperand(0); }
VectorType TransferWriteOp::getVectorType() {
  return cast<VectorType>(getValueToStore().getType());
}
/// Fold the pattern "read a full tensor, then write it back unchanged into a
/// same-typed destination" to the tensor that was read. The producer of the
/// destination may or may not be DCE'd afterwards.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // All indices must be zero.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getBase());
  return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getBase() == write.getBase() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}

/// Fold a write-after-read into the source tensor when the write writes back
/// exactly the values that were read.
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getBase());
  return success();
}
LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  if (succeeded(foldTransferFullMask(*this)))
    return success();
  return memref::foldMemRefCast(*this);
}

std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
namespace {
/// Remove dead transfer writes on tensors: a write is dead if a later write
/// fully overwrites it without an intervening read of the written region.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely
      // look at the one before it to see whether it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
/// Swap extract_slice(transfer_write) to transfer_write(extract_slice) when
/// the transfer_write fully overwrites the extracted slice, making the write
/// more local and removing the dependency on the large destination tensor.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if the TransferWriteOp or the ExtractSliceOp is rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if the ExtractSliceOp has a non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if the TransferWriteOp has a non-zero offset.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp has non-zero offset");
    }

    // Fail if the ExtractSliceOp and InsertSliceOp ranks differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the TransferWriteOp may not overwrite the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the ExtractSliceOp in front of the TransferWriteOp. Reset all
    // in_bounds flags to false and let the folder re-infer them.
    SmallVector<bool> newInBounds(vectorShape.size(), false);
    auto newExtractOp = tensor::ExtractSliceOp::create(
        rewriter, extractOp.getLoc(), insertOp.getSourceType(),
        insertOp.getDest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    auto newTransferWriteOp = TransferWriteOp::create(
        rewriter, transferOp.getLoc(), transferOp.getVector(),
        newExtractOp.getResult(), transferOp.getIndices(),
        transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(newInBounds));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};

} // namespace
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  return bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(), getResults());
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If rank == 0 or the vector has a single element, the memref layout does
  // not matter: there is only one way to access the element.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}
LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < resVecTy.getRank())
    return emitOpError(
        "destination memref has lower rank than the result vector");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
OpFoldResult LoadOp::fold(FoldAdaptor) {
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < valueVecTy.getRank())
    return emitOpError(
        "source memref has lower rank than the vector to store");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
/// Fold maskedloads with an all-true or all-false mask.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}

OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
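// Illustrative folds performed by MaskedLoadFolder above (operand names are
// placeholders):
//
//   %l = vector.maskedload %base[%i], %alltrue, %pass
//          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
//   ==>  %l = vector.load %base[%i] : memref<?xf32>, vector<16xf32>
//
//   %l = vector.maskedload %base[%i], %allfalse, %pass ...  ==>  %pass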
LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}
/// Fold maskedstores with an all-true or all-false mask.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
    return emitOpError(
        "alignment is only supported for memref bases, not tensor bases");
  }
  return success();
}
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
/// Returns success if the given index vector is a zero-based, contiguous
/// sequence, i.e. [0, 1, 2, ...] (either a vector.step or a constant).
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  if (indexVec.getDefiningOp<StepOp>())
    return success();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
/// Fold gathers with an all-false mask to their pass_thru value.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};
/// Rewrite a gather of a zero-based, contiguous index sequence as a (cheaper)
/// maskedload.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getOffsets(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};
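// Illustrative fold (assuming a memref base and a [0, 1, 2, 3] index vector):
//
//   %g = vector.gather %base[%c0][%cst], %mask, %pass
//   ==>  %g = vector.maskedload %base[%c0], %mask, %pass
//
// The gather of a zero-based contiguous sequence touches exactly the same
// elements as a contiguous masked load starting at the same offset.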
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (valueVType.getElementType() != baseType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
    return emitOpError(
        "alignment is only supported for memref bases, not tensor bases");
  }
  return success();
}
/// Fold scatters with an all-false mask: the scatter is a no-op.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    ShapedType baseType = scatter.getBaseType();
    bool isMemRef = isa<MemRefType>(baseType);
    if (!isMemRef && !isa<RankedTensorType>(baseType))
      return failure();
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      if (isMemRef)
        rewriter.eraseOp(scatter);
      else
        rewriter.replaceOp(scatter, scatter.getBase());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};
/// Rewrite a scatter of a zero-based, contiguous index sequence as a
/// (cheaper) maskedstore.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
    return success();
  }
};
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
/// Fold expandloads with an all-true or all-false mask.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
/// Fold compress_stores with an all-true or all-false mask.
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
LogicalResult ShapeCastOp::verify() {
  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}
/// Returns true if the transpose only permutes non-scalable unit dims, so
/// the relative order of the non-unit dims (and hence element order) is
/// preserved.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current)
        return false;
      current = p;
    }
  }
  return true;
}
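// Examples (illustrative): with source type vector<1x4xf32>, the permutation
// [1, 0] only moves a unit dim, so element order is preserved; with source
// type vector<2x4xf32> the same permutation swaps two non-unit dims and is
// not order-preserving. Scalable unit dims such as [1] are never treated as
// movable unit dims.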
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x).
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x), if the transpose preserves
  // element order.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return OpFoldResult();
  }

  // Y = shape_cast(broadcast(X)) -> X, if X and Y have the same type.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> constant.
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  return OpFoldResult();
}
/// Returns the vector type obtained by dropping trailing non-scalable unit
/// dims (at least one dim is always kept).
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
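// E.g. vector<4x1x1xi1> -> vector<4xi1> and vector<1x1xi1> -> vector<1xi1>
// (at least one dim is always kept), while a trailing scalable unit dim as
// in vector<4x[1]xi1> is not dropped.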
/// Rewrite shape_cast(create_mask) and shape_cast(constant_mask) as a mask
/// operation of the trimmed type, when the shape_cast only drops trailing
/// unit dims.
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Every mask dim size to drop must be exactly 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Every mask dim size to drop must be exactly 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
          shapeOp, shapeOpResTy, newMaskOperands);
      return success();
    }

    return failure();
  }
};
int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
                                     ArrayRef<int64_t> dstShape) {
  int stretchingFactor = 1;
  int numLeadingDims = dstShape.size() - srcShape.size();
  for (int i = 0, e = srcShape.size(); i < e; i++) {
    int64_t dstDim = dstShape[numLeadingDims + i];
    if (srcShape[i] == 1 && dstDim != 1) {
      stretchingFactor *= dstDim;
    }
  }
  return stretchingFactor;
}
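// Worked example (illustrative): for srcShape = [1, 2] and
// dstShape = [5, 4, 2], the leading 5 is a new leading dim
// (numLeadingDims = 1); src dim 0 (size 1) is stretched to 4 and src dim 1
// matches dst dim 2, so the stretching factor is 4.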
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;

    // Replace shape_cast(broadcast(X)) with broadcast(X) if the result type
    // is still broadcastable from X, and broadcasting to it stretches the
    // source by the same factor as the original broadcast.
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    ArrayRef<int64_t> dstShape = dstVectorType.getShape();
    ArrayRef<int64_t> srcShape =
        srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
    ArrayRef<int64_t> broadcastShape =
        broadcastOp.getResultVectorType().getShape();
    if (vector::isBroadcastableTo(broadcastOp.getSourceType(),
                                  dstVectorType) !=
        BroadcastableToResult::Success) {
      return failure();
    }

    if (!srcIsScalar && srcVectorType.getNumElements() != 1) {
      if (getBroadcastStretchingFactor(srcShape, dstShape) !=
          getBroadcastStretchingFactor(srcShape, broadcastShape)) {
        return failure();
      }
    }

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        shapeCastOp, dstVectorType, broadcastOp.getSource());
    return success();
  }
};
/// Fold shape_cast(from_elements) into a from_elements with the result type
/// of the shape_cast: the linearized element order is unchanged.
class FoldShapeCastOfFromElements final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(
        shapeCastOp, shapeCastOp.getResultVectorType(),
        fromElements.getElements());
    return success();
  }
};

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
              FoldShapeCastOfFromElements>(context);
}
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern across both halves of the 32-bit
        // value.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType) && srcElemType.isIntOrFloat()) {
        unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        unsigned dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting a splat of a narrow integer into a wider one.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the narrow pattern across the wider value.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(), intBits);
        }
      }
    }
  }

  return {};
}
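// Worked example for the integer splat rewrite above (illustrative): an i8
// splat of 0x2A bitcast to an i32 vector duplicates the source pattern
// 32 / 8 = 4 times, yielding the i32 splat 0x2A2A2A2A.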
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}
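// E.g. extractShape(memref<3x4xvector<8xf32>>) == {3, 4, 8}, while for a
// scalar element type extractShape(memref<3x4xf32>) == {3, 4}.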
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}
LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate identity transposes, and more generally any transpose that
  // only permutes unit dims and therefore preserves element order.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return OpFoldResult();
}
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}
std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
/// Composes two consecutive transpose ops into a single transpose.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};
/// Replaces transpose(splat-like(x)) with broadcast(x).
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(transposeOp.getVector());
    if (!splat)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transposeOp, transposeOp.getResultVectorType(), splat);
    return success();
  }
};
/// Folds transpose(create_mask) or transpose(constant_mask) into a new mask
/// op with permuted operands.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};
/// Folds transpose(shape_cast) into a new shape_cast when the transpose is
/// order-preserving.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};
/// Folds transpose(from_elements) into a from_elements with the elements
/// reordered according to the permutation.
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElementsOp =
        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
    if (!fromElementsOp)
      return failure();

    VectorType srcTy = fromElementsOp.getDest().getType();
    VectorType dstTy = transposeOp.getType();

    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    int64_t rank = srcTy.getRank();

    // Invert the permutation to map destination indices back to source ones.
    SmallVector<int64_t> inversePerm(rank, 0);
    for (int64_t i = 0; i < rank; ++i)
      inversePerm[permutation[i]] = i;

    ArrayRef<int64_t> srcShape = srcTy.getShape();
    ArrayRef<int64_t> dstShape = dstTy.getShape();
    SmallVector<int64_t> srcIdx(rank, 0);
    SmallVector<int64_t> dstIdx(rank, 0);
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);

    auto elementsOld = fromElementsOp.getElements();
    SmallVector<Value> elementsNew;
    int64_t dstNumElements = dstTy.getNumElements();
    elementsNew.reserve(dstNumElements);

    // Walk the destination elements in linear order and pick the matching
    // source element.
    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
      dstIdx = delinearize(linearIdx, dstStrides);
      for (int64_t j = 0; j < rank; ++j)
        srcIdx[j] = dstIdx[inversePerm[j]];
      int64_t srcLin = linearize(srcIdx, srcStrides);
      elementsNew.push_back(elementsOld[srcLin]);
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(transposeOp, dstTy,
                                                        elementsNew);
    return success();
  }
};
/// Replaces transpose(broadcast(x)) with broadcast(x) when the permutation
/// only moves dims within groups of broadcast dims.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        // Fail unless all permutation destinations for indices in [low, high)
        // stay within [low, high).
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    assert(vector::isBroadcastableTo(broadcast.getSourceType(), outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};
void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeFromElements,
              FoldTransposeBroadcast>(context);
}
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}
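// E.g. building with kind == AllTrue for vector<2x3xi1> produces mask dim
// sizes [2, 3] (the full shape), while AllFalse produces [0, 0].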
LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all are (the mask is a
  // conjunction of the per-dimension masks).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> bounds = getMaskDimSizes();
  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
  ArrayRef<int64_t> vectorSizes = resultType.getShape();

  auto createBoolSplat = [&](bool x) {
    return DenseElementsAttr::get(resultType, x);
  };

  // Check the corner case of 0-D vectors first.
  if (vectorSizes.empty()) {
    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
    return createBoolSplat(bounds[0] == 1);
  }
  // Fold when the mask is all-true or all-false.
  if (bounds == vectorSizes)
    return createBoolSplat(true);
  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
    return createBoolSplat(false);
  return OpFoldResult();
}
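// Illustrative folds: vector.constant_mask [2, 3] : vector<2x3xi1> folds to
// a dense<true> splat and [0, 0] folds to dense<false>; a partial mask such
// as [1, 3] is not folded.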
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
/// Folds a create_mask whose operands are all constants (or provably
/// all-true vscale multiples) into a constant_mask.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the constant mask dimension sizes.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value: any value is fine for a non-scalable dim, but only
        // zero (all-false) is supported for a scalable dim.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple: must be all-true to fold.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to the constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dim sizes is zero, set all dims to zero.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
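// Illustrative fold (constant operands only):
//
//   %m = vector.create_mask %c2, %c3 : vector<4x4xi1>
//   ==>  %m = vector.constant_mask [2, 3] : vector<4x4xi1>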
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();
  auto &builder = parser.getBuilder();

  // Parse the mask operand and an optional passthru operand.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse the op region and implicitly terminate it.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}))
    return failure();
  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse attributes, the mask type, and the optional result types.
  Type maskType;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(maskType))
    return failure();
  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();
  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }
  return success();
}
void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print the single masked operation, skipping the terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " } ";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.getBlocks().empty() || region.getBlocks().front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.getBlocks().front();
  if (isa<vector::YieldOp>(block.back()))
    return;

  // 3. For a non-empty `vector.mask` without an explicit terminator: create a
  // default terminator if the number of masked operations is not exactly one
  // (this case will trigger a verification failure).
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // Otherwise, create a terminator that yields the masked operation's results.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask: nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Folds empty `vector.mask` ops: those with only a terminator in the region
/// and no passthru operand.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();

  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  if (terminator.getNumOperands() == 0) {
    // `vector.mask` has no results: just remove the op.
    return success();
  }

  // `vector.mask` has results: propagate the yielded values.
  llvm::append_range(results, terminator.getOperands());
  return success();
}
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();

  // Fold an all-true mask by replacing the op with the masked operation.
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move the maskable operation outside of the `vector.mask` region.
  Operation *maskableOp = getMaskableOp();
  if (!maskableOp)
    return failure();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
/// Canonicalizes empty `vector.mask` ops that can't be handled in
/// `MaskOp::fold` because they require creating new operations.
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    if (!maskOp.isEmpty())
      return failure();

    if (!maskOp.hasPassthru())
      return failure();

    // Replace an empty mask with a select between the passthru and the
    // yielded value; this requires all results to have the mask's shape.
    VectorType maskType = maskOp.getMask().getType();
    for (Type resultType : maskOp.getResultTypes()) {
      auto vecResultType = dyn_cast<VectorType>(resultType);
      if (!vecResultType || vecResultType.getShape() != maskType.getShape())
        return failure();
    }

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
        terminator.getOperand(0), maskOp.getPassthru());

    return success();
  }
};
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}

Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) == rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRanges) {
  auto resultType = cast<VectorType>(getType());
  if (resultType.isScalable()) {
    return;
  }
  unsigned bitwidth =
      ConstantIntRanges::getStorageBitwidth(resultType.getElementType());
  APInt zero(bitwidth, 0);
  APInt high(bitwidth, resultType.getDimSize(0) - 1);
  ConstantIntRanges result = {zero, high, zero, high};
  setResultRanges(getResult(), result);
}
/// Folds `arith.cmpi` comparisons of a `vector.step` against a constant when
/// the outcome is the same for every lane.
struct StepCompareFolder : public OpRewritePattern<StepOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    const int64_t stepSize = stepOp.getResult().getType().getNumElements();

    for (OpOperand &use : stepOp.getResult().getUses()) {
      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
      if (!cmpiOp)
        continue;

      // arith.cmpi canonicalization moves constants to the rhs, so expect the
      // step on the lhs.
      const unsigned stepOperandNumber = use.getOperandNumber();
      if (stepOperandNumber != 0)
        continue;

      // Check that the other operand is a constant.
      unsigned constOperandNumber = 1;
      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
      std::optional<int64_t> maybeConstValue =
          getConstantIntValue(otherOperand);
      if (!maybeConstValue.has_value())
        continue;

      int64_t constValue = maybeConstValue.value();
      arith::CmpIPredicate pred = cmpiOp.getPredicate();

      // Decide whether the comparison folds to an all-true or an all-false
      // splat, if it folds at all. Step values are in [0, stepSize).
      auto maybeSplat = [&]() -> std::optional<bool> {
        if ((pred == arith::CmpIPredicate::ult ||
             pred == arith::CmpIPredicate::uge) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ult;

        if ((pred == arith::CmpIPredicate::ule ||
             pred == arith::CmpIPredicate::ugt) &&
            stepSize - 1 <= constValue) {
          return pred == arith::CmpIPredicate::ule;
        }

        if ((pred == arith::CmpIPredicate::eq ||
             pred == arith::CmpIPredicate::ne) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ne;

        return std::nullopt;
      }();

      if (!maybeSplat.has_value())
        continue;

      rewriter.setInsertionPointAfter(cmpiOp);

      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
      if (!type)
        continue;

      auto splatAttr = DenseElementsAttr::get(type, maybeSplat.value());
      Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
                                                    splatAttr);
      rewriter.replaceOp(cmpiOp, splat);
      return success();
    }

    return failure();
  }
};

void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<StepCompareFolder>(context);
}
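// Illustrative folds: with %step = vector.step : vector<4xindex> (lanes
// 0..3), "arith.cmpi ult, %step, %c4" holds for every lane and folds to a
// dense<true> splat; "arith.cmpi eq, %step, %c7" can never hold and folds to
// dense<false> (the "ne" form folds to dense<true>).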
/// Creates the vector.yield-terminated region of a vector.mask op with
/// `maskableOp` as the masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Move the op into the region's block and yield its results.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}

/// Wraps a maskable operation in a vector.mask operation. Returns the
/// maskable operation itself when no mask is provided.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return MaskOp::create(builder, maskableOp->getLoc(),
                          maskableOp->getResultTypes(), mask, passthru,
                          maskableOp, createMaskOpRegion);
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, maskableOp,
                        createMaskOpRegion);
}

/// Creates a vector select that picks `newValue` or `passthru` per lane,
/// based on `mask`.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return arith::SelectOp::create(builder, newValue.getLoc(),
                                 newValue.getType(), mask, newValue, passthru);
}
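// Usage sketch (illustrative): given a mask %m, a masked result %new and a
// passthru %p of matching shapes, selectPassthru materializes
//   %r = arith.select %m, %new, %p : vector<4xi1>, vector<4xf32>
// and simply returns %new when no mask is provided.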
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static OpFoldResult foldShuffleIdentityMask(ShuffleOp op)
Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static OpFoldResult foldShufflePoisonInputs(MLIRContext *context, Attribute v1Attr, Attribute v2Attr)
Fold shuffle poison, poison -> poison.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr, Attribute v2Attr)
Fold a shuffle of constant 1-D inputs by evaluating the mask.
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op)
If a shuffle operand is poison, replace all mask indices that reference it with kPoisonIndex.
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
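Taken together, these hooks support the usual shape of a custom parser. A hedged sketch for a hypothetical binary op with syntax `my.op %a, %b : type` (the op name and syntax are illustrative):

static ParseResult parseMyBinaryOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::UnresolvedOperand lhs, rhs;
  Type type;
  // my.op %a, %b {attrs} : type
  if (parser.parseOperand(lhs) || parser.parseComma() ||
      parser.parseOperand(rhs) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type))
    return failure();
  // Resolve the named operands against the parsed type.
  if (parser.resolveOperand(lhs, type, result.operands) ||
      parser.resolveOperand(rhs, type, result.operands))
    return failure();
  result.addTypes(type);
  return success();
}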
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
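A quick illustration of these getters (hedged; any Builder, e.g. one obtained from a rewriter, works the same way):

void buildSomeAttrs(Builder &b) {
  IntegerAttr idx = b.getIndexAttr(42);              // 42 : index
  IntegerAttr i64 = b.getI64IntegerAttr(7);          // 7 : i64
  ArrayAttr arr = b.getI64ArrayAttr({1, 2, 3});      // [1, 2, 3] as i64 attrs
  DenseI64ArrayAttr darr = b.getDenseI64ArrayAttr({4, 5});
  TypedAttr zero = b.getZeroAttr(b.getIntegerType(32)); // 0 : i32
  (void)idx; (void)i64; (void)arr; (void)darr; (void)zero;
}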
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
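Typical usage pairs the two calls (a minimal sketch; op is any operation nested under a module or another layout-carrying scope):

llvm::TypeSize typeSizeInBits(Operation *op, Type t) {
  // Walk up to the closest ancestor that carries data layout information,
  // then query the size of the type in that scope.
  DataLayout layout = DataLayout::closest(op);
  return layout.getTypeSizeInBits(t);
}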
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desired resultant type.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the given list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print method.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form, otherwise.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
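These pieces combine into the common clone-with-remap idiom. A hedged sketch (the helper name is illustrative):

Operation *cloneAfterWithNewOperand(OpBuilder &builder, Operation *op,
                                    Value oldOperand, Value newOperand) {
  // The RAII guard restores the builder's insertion point on exit.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointAfter(op);
  // Clone op right after itself, substituting newOperand for oldOperand.
  IRMapping mapping;
  mapping.map(oldOperand, newOperand);
  return builder.clone(*op, mapping);
}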
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
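A small, hedged illustration of the introspection surface described above (the checked condition is illustrative):

LogicalResult checkSingleResult(Operation *op) {
  // emitOpError prefixes the message with the op name and converts to
  // a failure LogicalResult.
  if (op->getNumResults() != 1)
    return op->emitOpError("expected exactly one result");
  // Result and operand types are exposed as ranges.
  Type resultType = op->getResultTypes().front();
  (void)resultType;
  return success();
}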
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification; the original op is erased.
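Inside a pattern, these rewriter methods are typically used as below (hedged sketch; the helper and its guard are illustrative):

static LogicalResult redirectFirstOperand(Operation *op, Value newVal,
                                          PatternRewriter &rewriter) {
  if (op->getNumOperands() == 0)
    return rewriter.notifyMatchFailure(op, "expected at least one operand");
  // In-place mutation must go through the rewriter so that any attached
  // listeners are notified of the change.
  rewriter.modifyOpInPlace(op, [&] { op->setOperand(0, newVal); });
  return success();
}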
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
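These predicates are the usual guards before a bit-width query, since getIntOrFloatBitWidth asserts on other types. A minimal sketch:

std::optional<unsigned> tryGetBitWidth(Type t) {
  // index and other non-int/float types have no fixed bit width;
  // getIntOrFloatBitWidth would assert on them.
  if (!t.isIntOrFloat())
    return std::nullopt;
  return t.getIntOrFloatBitWidth();
}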
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
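A hedged sketch of using these entry points to test whether two index values provably differ (the helper name is illustrative):

bool provablyDifferent(Value a, Value b) {
  // A constant delta is only available when the difference of the two
  // values folds to a constant; treat failure as "unknown".
  FailureOr<int64_t> delta =
      ValueBoundsConstraintSet::computeConstantDelta(a, b);
  return succeeded(delta) && *delta != 0;
}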
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble down, in place, a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for subscripts in the vector dialect.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g. the product of a constant and vector.vscale), return the multiplier.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector lane based on mask.
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the accessed tensor/memref to be the same.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operations to access the same tensor/memref.
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated with the given arith.atomicrmw kind.
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector.broadcast op.
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
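A hedged sketch of how the reduction helpers above compose (the helper name is illustrative; operand values and the mask are assumed to have matching vector types):

Value maskedAdd(OpBuilder &b, Location loc, Value lhs, Value acc, Value mask) {
  // Combine lhs into acc with the ADD combining kind; when a mask is
  // provided, the helper emits the masked form of the reduction.
  return vector::makeArithReduction(b, loc, vector::CombiningKind::ADD, lhs,
                                    acc, /*fastmath=*/nullptr, mask);
}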
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
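matchPattern is the standard way to test a Value against matchers such as m_Constant or m_Poison. A small hedged sketch (the helper name is illustrative):

bool getFoldableConstant(Value v, Attribute &attrOut) {
  // On success, the matcher binds the constant's attribute to attrOut.
  return matchPattern(v, m_Constant(&attrOut));
}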
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
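The strides helpers implement ordinary row-major index arithmetic. A small hedged example with concrete numbers:

void stridesExample() {
  // For sizes [2, 3, 4], the row-major strides are [12, 4, 1].
  SmallVector<int64_t> strides = computeStrides({2, 3, 4});
  // Linear index 17 = 1*12 + 1*4 + 1*1 delinearizes to [1, 1, 1].
  SmallVector<int64_t> offsets = delinearize(17, strides);
  (void)offsets;
}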
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. the given 'basis'.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
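A worked example for the permutation utilities (hedged; plain data, no IR involved):

void permutationExample() {
  SmallVector<int64_t> values = {10, 20, 30};
  // result[i] = values[perm[i]]: perm [2, 0, 1] yields [30, 10, 20].
  SmallVector<int64_t> permuted =
      applyPermutation(ArrayRef<int64_t>(values), {2, 0, 1});
  // The inverse of [2, 0, 1] is [1, 2, 0]; applying it restores the input.
  SmallVector<int64_t> inverse = invertPermutationVector({2, 0, 1});
  (void)permuted;
  (void)inverse;
}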
Return a fused vector::ContractionOp which represents a pattern such as an elementwise addition being folded into the accumulator of a preceding contraction.
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
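The pattern classes above all follow the same OpRewritePattern skeleton. A hedged, generic sketch (the pattern name and rewrite are illustrative, not the upstream canonicalizations listed above):

struct EraseIdentityTranspose : OpRewritePattern<vector::TransposeOp> {
  using Base::Base; // inherit the OpRewritePattern constructors

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // A sorted permutation of [0, ..., n-1] is the identity, which makes
    // the transpose a no-op.
    if (!llvm::is_sorted(op.getPermutation()))
      return rewriter.notifyMatchFailure(op, "not an identity permutation");
    rewriter.replaceOp(op, op.getVector());
    return success();
  }
};

Such a pattern would be registered with patterns.add<EraseIdentityTranspose>(ctx), per the RewritePatternSet::add entry above.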
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)