#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
    int64_t val = 0;
    for (bool b : denseElts.getValues<bool>())
      if (b && val >= 0)
        val++;
      else if (!b && val <= 0)
        val--;
    auto shape = m.getType().getShape();
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  vector::YieldOp::create(builder, loc);
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
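// Note on the kinds above: ADD and MUL accept integer, index, and float
// element types; the signed/unsigned min/max kinds and the bitwise
// AND/OR/XOR kinds are integer/index only; the *F kinds require a float
// element type. Illustrative IR (an assumed example, not from this file):
//   %ok  = vector.reduction <add>, %v : vector<4xi32> into i32
//   %bad = vector.reduction <maximumf>, %v : vector<4xi32> into i32  // rejected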
                                              VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
                                              VectorType vectorType) {
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
                                      vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||

                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
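// Illustrative sketch of the store-to-load forwarding these helpers enable
// (assumed IR, not from this file): given
//   vector.transfer_write %v, %m[%i] {in_bounds = [true]} : vector<4xf32>, memref<8xf32>
//   %r = vector.transfer_read %m[%i], %pad : memref<8xf32>, vector<4xf32>
// the read can be replaced by %v when the indices, vector types, permutation
// maps, and masks all match and the write has no out-of-bounds dims.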
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  if (transferA.getVectorType() != transferB.getVectorType())
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    if (i < rankOffset) {
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
      if (testDynamicValueUsingBounds) {
        FailureOr<uint64_t> delta =
        if (succeeded(delta) && *delta != 0)
        FailureOr<bool> testEqual =
        if (succeeded(testEqual) && !testEqual.value())
    int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
    if (cstIndexA.has_value() && cstIndexB.has_value()) {
      int64_t distance = std::abs(*cstIndexA - *cstIndexB);
      if (distance >= vectorDim)
    if (testDynamicValueUsingBounds) {
      FailureOr<int64_t> delta =
      if (succeeded(delta) && std::abs(*delta) >= vectorDim)
      FailureOr<int64_t> computeDelta =
      if (succeeded(computeDelta)) {
        if (std::abs(computeDelta.value()) >= vectorDim)
                                   VectorTransferOpInterface transferB,
                                   bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
                                    testDynamicValueUsingBounds);
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    if (posInDim < dimSize + offsetInDim)
    posInDim = offsetInDim;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();

      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();

  llvm::transform(foldResults, std::back_inserter(values),
                  if (auto attr = dyn_cast<Attribute>(foldResult))
                        builder, loc, cast<IntegerAttr>(attr).getInt())
                  return cast<Value>(foldResult);
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
  if (rhs.getDefiningOp<vector::VectorScaleOp>())

  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto intType = dyn_cast<IntegerType>(expectedType)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());

  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    auto intType = dyn_cast<IntegerType>(expectedType);
    APFloat floatVal = floatAttr.getValue();
    APInt intVal = floatVal.bitcastToAPInt();
    return IntegerAttr::get(expectedType, intVal);
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

void VectorDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();

  if (isa<ub::PoisonAttrInterface>(value))
  return arith::ConstantOp::materialize(builder, value, type, loc);
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        CombiningKind kind) {
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
LogicalResult MultiDimReductionOp::verify() {
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);

  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
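// Example (illustrative): a masked reduction such as
//   %r = vector.multi_reduction <add>, %src, %acc [1] : vector<2x4xf32> to vector<2xf32>
// expects a mask matching the *source* shape, i.e. vector<2x4xi1>.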
struct ElideUnitDimsInMultiDimReduction
  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)

    OpBuilder::InsertionGuard guard(rewriter);
    if (reductionOp.isMasked()) {
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      VectorType newMaskType =
          VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                          dstVecType.getScalableDims());
      mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
      mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
        cast, nullptr, mask);
void MultiDimReductionOp::getCanonicalizationPatterns(
  results.add<ElideUnitDimsInMultiDimReduction>(context);
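// Example of the rewrite above (illustrative): when every reduced dimension
// is a unit dimension, the reduction degenerates into a shape change, e.g.
//   %r = vector.multi_reduction <add>, %src, %acc [1] : vector<4x1xf32> to vector<4xf32>
// becomes a vector.shape_cast of %src to vector<4xf32>, combined with %acc.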
                                 arith::FastMathFlags fastMathFlags) {

                                 arith::FastMathFlags fastMathFlags) {
      llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,

LogicalResult ReductionOp::verify() {
  int64_t rank = getSourceVectorType().getRank();
    return emitOpError("unsupported reduction rank: ") << rank;

  Type eltType = getDest().getType();
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::OR, vector);
  case arith::AtomicRMWKind::minnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINNUMF, vector);
  case arith::AtomicRMWKind::maxnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXNUMF, vector);
  case arith::AtomicRMWKind::xori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::XOR, vector);
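// Usage sketch (hedged: assumes this switch lives in the helper exposed as
// mlir::vector::getVectorReductionOp, mapping atomic-RMW reduction kinds
// used by affine/memref utilities onto vector.reduction):
//   Value red = vector::getVectorReductionOp(arith::AtomicRMWKind::addf,
//                                            builder, loc, vec);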
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
  LogicalResult matchAndRewrite(ReductionOp reductionOp,
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    if (maskableOp.isMasked()) {
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
                                   reductionOp.getFastmathAttr(), mask);

  results.add<ElideSingleElementReduction>(context);
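// Example of the rewrite above (illustrative): a reduction over a single
// element is just that element, e.g.
//   %r = vector.reduction <add>, %v : vector<1xf32> into f32
// becomes
//   %r = vector.extract %v[0] : f32 from vector<1xf32>
// (combined with the accumulator and mask when present).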
               getIndexingMapsAttrName(result.name),
               getIteratorTypesAttrName(result.name),
                 return IteratorTypeAttr::get(builder.getContext(), t);
        ContractionOp::getDefaultKind());

                           ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
                      CombiningKindAttr::get(builder.getContext(), kind));
  DictionaryAttr dictAttr;
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  result.attributes.set(getIteratorTypesAttrName(result.name),

  if (!result.attributes.get(getKindAttrName(result.name))) {
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));
  if (masksInfo.empty())
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  std::array<VectorType, 2> maskTypes = {
  auto attrNames = getTraitAttrNames();
  traitAttrsSet.insert_range(attrNames);
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
      llvm::make_second_range(batchDimMap));

  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
    expectedResultDims.push_back(lhsType.getDimSize(i));

  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
    expectedResultDims.push_back(rhsType.getDimSize(i));

  if (expectedResultDims.empty()) {
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    auto expectedShape = llvm::to_vector<4>(
          return cast<AffineConstantExpr>(e).getValue();
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");

  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

    return emitOpError("invalid batch dimension map");

                               contractingDimMap, batchDimMap)))

  if (!getKindAttr()) {
    return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
                       "'vector.kind<add>')");

  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
    return emitOpError("unsupported contraction type");

  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
        lhsType.getScalableDims()[dimIdx];
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
        rhsType.getScalableDims()[dimIdx];

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), 1),
                         maskShapeScalableDims);
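// Example (illustrative): for a canonical matmul contraction with iterator
// types [parallel, parallel, reduction] over (m, n, k), lhs vector<2x4xf32>
// indexed by (m, k) and rhs vector<4x3xf32> indexed by (k, n), the expected
// mask spans the whole iteration space: vector<2x3x4xi1> for (m, n, k).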
          getIteratorTypesAttrName(), getKindAttrName()};

static std::vector<std::pair<int64_t, int64_t>>
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
void ContractionOp::getIterationBounds(
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
void ContractionOp::getIterationIndexMap(
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  getIterationBounds(shape);
template <typename AddOpType>
  auto canonicalize = [&](Value maybeContraction,
                          Value otherOperand) -> vector::ContractionOp {
    vector::ContractionOp contractionOp =
        dyn_cast_or_null<vector::ContractionOp>(
      return vector::ContractionOp();
    if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
            contractionOp.getAcc().getDefiningOp())) {
      if (maybeZero.getValue() ==
          rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
        bvm.map(contractionOp.getAcc(), otherOperand);
        auto newContraction =
            cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
        rewriter.replaceOp(addOp, newContraction.getResult());
        return newContraction;
    return vector::ContractionOp();

  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
  vector::ContractionOp contract = canonicalize(a, b);
  setResultRanges(getResult(), argRanges.front());

  auto vectorTy = cast<VectorType>(source.getType());
  build(builder, result, source, dynamicPos,

ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitOpError(
          "expected a scalar instead of a 0-d vector as the result type");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
template <typename IntType>
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));

  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
  if (extractOp.hasDynamicPosition())
  ExtractOp currentOp = extractOp;
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
    if (currentOp.hasDynamicPosition())
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  extractOp.setOperand(0, currentOp.getSource());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
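// Example of the chain fold above (illustrative):
//   %a = vector.extract %v[0] : vector<4x8xf32> from vector<2x4x8xf32>
//   %b = vector.extract %a[1] : vector<8xf32> from vector<4x8xf32>
// composes the static positions into
//   %b = vector.extract %v[0, 1] : vector<8xf32> from vector<2x4x8xf32>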
class ExtractFromInsertTransposeChainState {
  ExtractFromInsertTransposeChainState(ExtractOp e);

  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());

  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)

    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));

  void updateStateForNextIteration(Value v) {

  LogicalResult handleTransposeOp();
  LogicalResult handleInsertOpWithMatchingPos(Value &res);
  LogicalResult handleInsertOpWithPrefixPos(Value &res);
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
                         extractOp.getStaticPosition().end());
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (extractOp.hasDynamicPosition())
  if (!nextTransposeOp)
      nextTransposeOp.getPermutation(), extractOp.getContext()));

ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
  res = nextInsertOp.getValueToStore();

ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  res = nextInsertOp.getValueToStore();

Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
  if (extractOp.hasDynamicPosition())
  bool nothingToFold = (source == extractOp.getSource());
  if (nothingToFold || !canFold())
  OpBuilder b(extractOp.getContext());
  extractOp.setStaticPosition(
  extractOp.getSourceMutable().assign(source);
  return extractOp.getResult();

Value ExtractFromInsertTransposeChainState::fold() {
  if (extractOp.hasDynamicPosition())
  Value valueToExtractFrom = extractOp.getSource();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
    if (succeeded(handleInsertOpWithMatchingPos(result)))
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);
    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;

  if (isa<BroadcastOp>(op))
  auto shapeCast = dyn_cast<ShapeCastOp>(op);
  VectorType srcType = shapeCast.getSourceVectorType();
  uint64_t srcRank = srcType.getRank();
  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
  Operation *defOp = extractOp.getSource().getDefiningOp();
  if (extractOp.getType() == input.getType())
  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;
  if (extractRank > inputRank)
  assert(inputType && "input must be a vector type because of previous checks");
      extractType.getShape() != inputShape.take_back(extractRank))
  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;
  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
  if (shuffleOp.getResultVectorType().getRank() != 1)
  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
  std::reverse(extractedPos.begin(), extractedPos.end());
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  std::reverse(newStrides.begin(), newStrides.end());
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
  auto extractStridedSliceOp =
      extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
  if (extractStridedSliceOp.hasNonUnitStrides())
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
    sliceOffsets.pop_back();
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
    bool disjoint = false;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
          (dim < insertRankDiff)
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t offset = extractOffsets[dim];
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
      extractOp.getSourceMutable().assign(insertOp.getValueToStore());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  if (extractOp.hasDynamicPosition())
  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
  int64_t rank = vecType.getRank();
  if (extractOp.getType() != vecType.getElementType())
         "unexpected number of indices");
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  return fromElementsOp.getElements()[flatIndex];
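// Worked example of the row-major linearization above (illustrative): for a
// vector<2x3xf32> built by vector.from_elements, position [1, 2] selects
// operand flatIndex = 1 * 3 + 2 = 5.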
template <typename OpType, typename AdaptorType>
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();

  if (!dynamicPosition.size())

  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (ShapedType::isStatic(staticPosition[i]))
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
        staticPosition[i] = attr.getInt();
    operands.push_back(position);

  op.setStaticPosition(staticPosition);
  op.getOperation()->setOperands(operands);
  return op.getResult();
  if (!is_contained(staticPos, poisonVal))
  return ub::PoisonAttr::get(context);

  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
  if (denseAttr.isSplat()) {
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())
  if (extractOp.hasDynamicPosition()) {
  copy(extractOp.getStaticPosition(), completePositions.begin());
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
  } else {
    newAttr = *denseValuesBegin;
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
  SmallVector<Value> operands = {getSource()};
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
  return inplaceFolded;
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
        BroadcastableToResult::Success)
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
        containsUnknownDims = true;
      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
      if (pos != ShapedType::kDynamic) {
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        containsUnknownDims = true;
    } else if (!containsUnknownDims) {
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (sourceType.getNumElements() != targetType.getNumElements())
                                            castOp.getSource());
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  if (extractOp.hasDynamicPosition())
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
  VectorType inputType = fromElementsOp.getType();
  if (resultType.isScalable() || inputType.isScalable())
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {

  if (operands.empty())
  return llvm::all_of(operands, [&](Value operand) {
    return currentDef == defOp;
  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
  llvm::append_range(results, fromElementsOp.getElements());

  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
  if (isa<VectorType>(bcastOp.getSource().getType()))
  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
  Value scalar = bcastOp.getSource();
  results.assign(resultVecType.getNumElements(), scalar);

LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {

ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                               ToElementsOp::Adaptor adaptor,
                               SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vecType = cast<VectorType>(adaptor.getSource().getType());
  Type elType = vecType.getElementType();
  inferredReturnTypes.append(vecType.getNumElements(), elType);
    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
    int64_t dstRank = dstShape.size();
    int64_t srcRank = srcShape.size();
    auto srcElems = vector::ToElementsOp::create(
        rewriter, toElementsOp.getLoc(), bcastOp.getSource());
    int64_t dstCount = llvm::product_of(dstShape);
    replacements.reserve(dstCount);
    for (int64_t lin = 0; lin < dstCount; ++lin) {
      for (int64_t k = 0; k < srcRank; ++k)
        srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
      replacements.push_back(srcElems.getResult(srcLin));
    rewriter.replaceOp(toElementsOp, replacements);

void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ToElementsOfBroadcast>(context);
  OperandRange fromElemsOperands = fromElementsOp.getElements();
  if (fromElemsOperands.empty())
  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
  Value toElementsInput = toElementsOp.getSource();
  if (fromElementsOp.getType() == toElementsInput.getType() &&
      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
    return toElementsInput;

  if (llvm::any_of(elements, [](Attribute attr) {
        return !attr || isa<ub::PoisonAttrInterface>(attr);
  auto destVecType = fromElementsOp.getDest().getType();
  auto destEltType = destVecType.getElementType();
  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {

  if (!llvm::all_equal(fromElementsOp.getElements()))
      fromElementsOp, fromElementsOp.getType(),
      fromElementsOp.getElements().front());
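// Example of the rewrite above (illustrative): when all operands are the
// same scalar, e.g.
//   %v = vector.from_elements %s, %s, %s, %s : vector<4xf32>
// the op is replaced by a broadcast (splat) of that scalar:
//   %v = vector.broadcast %s : f32 to vector<4xf32>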
  LogicalResult matchAndRewrite(FromElementsOp fromElements,
    if (fromElements.getType().getNumElements() == 1)
    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {
      auto extractOp = element.getDefiningOp<vector::ExtractOp>();
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");
      if (insertIndex == 0) {
        source = extractOp.getSource();
      } else if (extractOp.getSource() != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");
      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements do not form a suffix of source");
        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);
      } else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      increment(expectedPosition, source.getType().getShape());
    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);
        fromElements, fromElements.getType(), extracted);
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
  setResultRanges(getResult(), argRanges.front());

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
static llvm::SetVector<int64_t>
  int64_t rankDiff = dstShape.size() - srcShape.size();
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    assert(s1 == 1 && "expected \"dim-1\" broadcasting");

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
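// Example (illustrative): broadcasting vector<1x2xf32> to vector<4x3x2xf32>
// stretches the source's unit dim into size 3, so the computed "dim-1"
// broadcasted unit dims of the result are {1}; the prepended rank-extension
// dim 0 is not counted.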
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
    checkShape.push_back(dstShape[i]);
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "

  Location loc = value.getLoc();
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
      permutation[i] = nextSrcShapeDim++;
  llvm::append_range(broadcastShape, srcVectorType.getShape());
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
         vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
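// Worked example (illustrative values): value : vector<4x2xf32>,
// dstShape = {4, 3, 2}, broadcastedDims = {1}. The helper first broadcasts
// to vector<3x4x2xf32> (broadcast dims leading) and then transposes with
// permutation [1, 0, 2] to produce the requested vector<4x3x2xf32>.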
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    bool foundMismatchingDims = false;
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;
    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;
        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;

LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
      getSourceType(), getResultVectorType(), &mismatchingDims);
    return emitOpError("source rank higher than destination rank");
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
  VectorType srcType = srcShapeCast.getSourceVectorType();
  VectorType destType = broadcastOp.getResultVectorType();
      srcShapeCast.getResultVectorType().getShape();
  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
  if (!llvm::equal(srcShape.take_back(numTrailingDims),
                   shapecastShape.take_back(numTrailingDims)))
  assert(all_of(srcShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         all_of(shapecastShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         "ill-formed shape_cast");
  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
  if (!adaptor.getSource())
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
                                             broadcastOp.getResultVectorType(),
                                             srcBroadcast.getSource());

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder>(context);
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();

  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)

  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)

  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
  if (maskLength != resultType.getDimSize(0))

  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
      return emitOpError("mask index #") << (idx + 1) << " out of range";
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));

template <typename T>
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
    return value == expected++;
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  if (v1Type.getRank() == 0)

  auto mask = getMask();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)

  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)

  if (v1Type.getRank() != 1)

  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
  v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
  poisonElement = v2Elements[0];
  auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
  v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
  poisonElement = v1Elements[0];

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    results.push_back(indexedElm);
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
                                PatternRewriter &rewriter) const override {
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
    if (mask.size() != 1)
    VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
static Value getScalarSplatSource(Value value) {
  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
  if (isa<VectorType>(broadcast.getSourceType()))

class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getV1());
    if (!splat || getScalarSplatSource(op.getV2()) != splat)
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent a scalable interleave");
    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp types don't match an interleave");

    ArrayRef<int64_t> shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
        return rewriter.notifyMatchFailure(op,
                                           "ShuffleOp mask not interleaving");
3347void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3348 MLIRContext *context) {
3349 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3357void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3359 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3362void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3363 Value source, Value dest) {
3364 auto vectorTy = cast<VectorType>(dest.
getType());
3365 build(builder,
result, source, dest,
3366 SmallVector<int64_t>(vectorTy.getRank(), 0));
3369void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3370 Value source, Value dest, int64_t position) {
3371 build(builder,
result, source, dest, ArrayRef<int64_t>{position});
3374void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3375 Value source, Value dest, OpFoldResult position) {
3376 build(builder,
result, source, dest, ArrayRef<OpFoldResult>{position});
3379void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3380 Value source, Value dest,
3381 ArrayRef<int64_t> position) {
3382 SmallVector<OpFoldResult> posVals;
3383 posVals.reserve(position.size());
3384 llvm::transform(position, std::back_inserter(posVals),
3386 build(builder,
result, source, dest, posVals);
3389void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3390 Value source, Value dest,
3391 ArrayRef<OpFoldResult> position) {
3392 SmallVector<int64_t> staticPos;
3393 SmallVector<Value> dynamicPos;
3395 build(builder,
result, source, dest, dynamicPos,
3399LogicalResult InsertOp::verify() {
3400 if (
auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3401 if (srcTy.getRank() == 0)
3403 "expected a scalar instead of a 0-d vector as the source operand");
3405 SmallVector<OpFoldResult> position = getMixedPosition();
3406 auto destVectorType = getDestVectorType();
3407 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
3409 "expected position attribute of rank no greater than dest vector rank");
3410 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3411 if (srcVectorType &&
3412 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3413 static_cast<unsigned>(destVectorType.getRank())))
3414 return emitOpError(
"expected position attribute rank + source rank to "
3415 "match dest vector rank");
3416 if (!srcVectorType &&
3417 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
3419 "expected position attribute rank to match the dest vector rank");
3420 for (
auto [idx, pos] : llvm::enumerate(position)) {
3421 if (
auto attr = dyn_cast<Attribute>(pos)) {
3422 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3424 destVectorType.getDimSize(idx))) {
3425 return emitOpError(
"expected position attribute #")
3427 <<
" to be a non-negative integer smaller than the "
3429 "dest vector dimension";
3442 assert(positions.size() <= completePositions.size() &&
3443 "positions size must be less than or equal to destTy rank");
3444 copy(positions, completePositions.begin());
3452class InsertToBroadcast final :
public OpRewritePattern<InsertOp> {
3456 LogicalResult matchAndRewrite(InsertOp insertOp,
3457 PatternRewriter &rewriter)
const override {
3459 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3460 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3461 srcVecType.getNumElements())
3464 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3470class InsertSplatToSplat final :
public OpRewritePattern<InsertOp> {
3474 LogicalResult matchAndRewrite(InsertOp op,
3475 PatternRewriter &rewriter)
const override {
3477 Value splat = getScalarSplatSource(op.getValueToStore());
3478 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3506class InsertChainFullyInitialized final :
public OpRewritePattern<InsertOp> {
3509 LogicalResult matchAndRewrite(InsertOp op,
3510 PatternRewriter &rewriter)
const override {
3512 VectorType destTy = op.getDestVectorType();
3513 if (destTy.isScalable())
3516 for (Operation *user : op.getResult().getUsers())
3517 if (
auto insertOp = dyn_cast<InsertOp>(user))
3518 if (insertOp.getDest() == op.getResult())
3521 InsertOp currentOp = op;
3522 SmallVector<InsertOp> chainInsertOps;
3525 if (currentOp.hasDynamicPosition())
3528 chainInsertOps.push_back(currentOp);
3529 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3532 if (currentOp && !currentOp->hasOneUse())
3536 int64_t vectorSize = destTy.getNumElements();
3537 int64_t initializedCount = 0;
3538 SmallVector<bool> initializedDestIdxs(vectorSize,
false);
3539 SmallVector<int64_t> pendingInsertPos;
3540 SmallVector<int64_t> pendingInsertSize;
3541 SmallVector<Value> pendingInsertValues;
3543 for (
auto insertOp : chainInsertOps) {
3545 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3549 int64_t insertBeginPosition =
3554 int64_t insertSize = 1;
3555 if (
auto srcVectorType =
3556 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3557 insertSize = srcVectorType.getNumElements();
3559 assert(insertBeginPosition + insertSize <= vectorSize &&
3560 "insert would overflow the vector");
3562 for (
auto index : llvm::seq<int64_t>(insertBeginPosition,
3563 insertBeginPosition + insertSize)) {
3564 if (initializedDestIdxs[index])
3566 initializedDestIdxs[index] =
true;
3572 pendingInsertPos.push_back(insertBeginPosition);
3573 pendingInsertSize.push_back(insertSize);
3574 pendingInsertValues.push_back(insertOp.getValueToStore());
3576 if (initializedCount == vectorSize)
3581 if (initializedCount != vectorSize)
3584 SmallVector<Value> elements(vectorSize);
3585 for (
auto [insertBeginPosition, insertSize, valueToStore] :
3586 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3587 pendingInsertValues))) {
3588 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3590 if (!srcVectorType) {
3591 elements[insertBeginPosition] = valueToStore;
3595 SmallVector<Type> elementToInsertTypes(insertSize,
3596 srcVectorType.getElementType());
3598 auto elementsToInsert = vector::ToElementsOp::create(
3599 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3600 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3601 elements[insertBeginPosition + linearIdx] =
3602 elementsToInsert.getResult(linearIdx);
3616 int64_t maxVectorSizeFoldThreshold) {
3617 if (insertOp.hasDynamicPosition())
3620 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3628 VectorType destTy = insertOp.getDestVectorType();
3629 if (destTy.isScalable())
3633 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3634 !insertOp->hasOneUse())
3641 Type destEltType = destTy.getElementType();
3645 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3646 for (
auto value : denseSource.getValues<
Attribute>())
3652 auto allValues = llvm::to_vector(denseDst.getValues<
Attribute>());
3653 copy(insertedValues, allValues.begin() + insertBeginPosition);
3662 auto destInsert = insertOp.getDest().
getDefiningOp<InsertOp>();
3666 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3669 insertOp.
setOperand(1, destInsert.getDest());
3670 return insertOp.getResult();
3673void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3674 MLIRContext *context) {
3675 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3676 InsertChainFullyInitialized>(context);
3679OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3682 constexpr int64_t vectorSizeFoldThreshold = 256;
3686 if (getNumIndices() == 0 && getValueToStoreType() ==
getType())
3687 return getValueToStore();
3691 SmallVector<Value> operands = {getValueToStore(), getDest()};
3697 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3700 *
this, adaptor.getValueToStore(), adaptor.getDest(),
3701 vectorSizeFoldThreshold)) {
3705 return inplaceFolded;
3712void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
3713 Value source, Value dest,
3714 ArrayRef<int64_t> offsets,
3715 ArrayRef<int64_t> strides) {
3716 result.addOperands({source, dest});
3720 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(
result.name),
3722 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(
result.name),
3727template <
typename OpType>
3731 StringRef attrName) {
3732 if (arrayAttr.size() >
shape.size())
3733 return op.emitOpError(
"expected ")
3734 << attrName <<
" attribute of rank no greater than vector rank";
3741template <
typename OpType>
3745 bool halfOpen =
true) {
3746 for (
auto attr : arrayAttr) {
3747 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3751 if (val < min || val >= upper)
3752 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
3753 <<
min <<
", " << upper <<
")";
3761template <
typename OpType>
3766 for (
auto [
index, attrDimPair] :
3767 llvm::enumerate(llvm::zip_first(arrayAttr,
shape))) {
3768 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3772 if (val < min || val >=
max)
3773 return op.emitOpError(
"expected ")
3774 << attrName <<
" dimension " <<
index <<
" to be confined to ["
3775 <<
min <<
", " <<
max <<
")";
3785template <
typename OpType>
3790 assert(arrayAttr1.size() <=
shape.size());
3791 assert(arrayAttr2.size() <=
shape.size());
3792 for (
auto [
index, it] :
3793 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2,
shape))) {
3794 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3795 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3799 if (val1 + val2 < 0 || val1 + val2 >=
max)
3800 return op.emitOpError(
"expected sum(")
3801 << attrName1 <<
", " << attrName2 <<
") dimension " <<
index
3802 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3810 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3812 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3815LogicalResult InsertStridedSliceOp::verify() {
3816 auto sourceVectorType = getSourceVectorType();
3817 auto destVectorType = getDestVectorType();
3818 auto offsets = getOffsetsAttr();
3819 auto strides = getStridesAttr();
3820 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3822 "expected offsets of same size as destination vector rank");
3823 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3824 return emitOpError(
"expected strides of same size as source vector rank");
3825 if (sourceVectorType.getRank() > destVectorType.getRank())
3827 "expected source rank to be no greater than destination rank");
3829 auto sourceShape = sourceVectorType.getShape();
3830 auto destShape = destVectorType.getShape();
3831 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3832 destShape.size() - sourceShape.size(), 0);
3833 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3834 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3835 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3844 offName,
"source vector shape",
3848 unsigned rankDiff = destShape.size() - sourceShape.size();
3849 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3850 if (sourceVectorType.getScalableDims()[idx] !=
3851 destVectorType.getScalableDims()[idx + rankDiff]) {
3852 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3855 if (sourceVectorType.getScalableDims()[idx]) {
3856 auto sourceSize = sourceShape[idx];
3857 auto destSize = destShape[idx + rankDiff];
3858 if (sourceSize != destSize) {
3861 << (
" to match the corresponding base size from the input "
3863 << sourceSize << (
" vs ") << destSize << (
")");
3873class FoldInsertStridedSliceSplat final
3874 :
public OpRewritePattern<InsertStridedSliceOp> {
3878 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3879 PatternRewriter &rewriter)
const override {
3881 auto dst = insertStridedSliceOp.getDest();
3882 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3883 if (!splat || getScalarSplatSource(dst) != splat)
3886 rewriter.
replaceOp(insertStridedSliceOp, dst);
3893class FoldInsertStridedSliceOfExtract final
3894 :
public OpRewritePattern<InsertStridedSliceOp> {
3898 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3899 PatternRewriter &rewriter)
const override {
3900 auto extractStridedSliceOp =
3901 insertStridedSliceOp.getValueToStore()
3902 .getDefiningOp<vector::ExtractStridedSliceOp>();
3904 if (!extractStridedSliceOp)
3907 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3911 if (extractStridedSliceOp.getStrides() !=
3912 insertStridedSliceOp.getStrides() ||
3913 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3916 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3923class InsertStridedSliceConstantFolder final
3924 :
public OpRewritePattern<InsertStridedSliceOp> {
3930 static constexpr int64_t vectorSizeFoldThreshold = 256;
3932 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3933 PatternRewriter &rewriter)
const override {
3937 Attribute vectorDestCst;
3941 VectorType destTy = destVector.getType();
3942 if (destTy.isScalable())
3946 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3947 !destVector.hasOneUse())
3951 Attribute sourceCst;
3956 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3960 if (op.hasNonUnitStrides())
3963 VectorType sliceVecTy = sourceValue.getType();
3964 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3965 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3966 SmallVector<int64_t, 4> offsets =
getI64SubArray(op.getOffsets());
3967 SmallVector<int64_t, 4> destStrides =
computeStrides(destTy.getShape());
3975 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3976 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3977 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3978 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3979 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3980 MutableArrayRef<int64_t> currSlicePosition(
3981 currDestPosition.begin() + rankDifference, currDestPosition.end());
3982 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3985 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
3986 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
3987 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3988 "Invalid slice element");
3989 newValues[linearizedPosition] = *sliceValuesIt;
4002void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4003 RewritePatternSet &results, MLIRContext *context) {
4004 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4005 InsertStridedSliceConstantFolder>(context);
4008OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4009 if (getSourceVectorType() == getDestVectorType())
4010 return getValueToStore();
4019void OuterProductOp::build(OpBuilder &builder, OperationState &
result,
4020 Value
lhs, Value
rhs, Value acc) {
4025void OuterProductOp::print(OpAsmPrinter &p) {
4026 p <<
" " << getLhs() <<
", " << getRhs();
4028 p <<
", " << getAcc();
4031 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
4034ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &
result) {
4035 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4042 if (operandsInfo.size() < 2)
4044 "expected at least 2 operands");
4045 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4046 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4049 "expected vector type for operand #1");
4053 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4054 vRHS.getScalableDims()[0]};
4055 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4056 vLHS.getElementType(), scalableDimsRes);
4059 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4060 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4064 if (!
result.attributes.get(OuterProductOp::getKindAttrName(
result.name))) {
4065 result.attributes.append(
4066 OuterProductOp::getKindAttrName(
result.name),
4067 CombiningKindAttr::get(
result.getContext(),
4068 OuterProductOp::getDefaultKind()));
4074 (operandsInfo.size() > 2 &&
4079LogicalResult OuterProductOp::verify() {
4080 Type tRHS = getOperandTypeRHS();
4081 VectorType vLHS = getOperandVectorTypeLHS(),
4082 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4083 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4085 if (vLHS.getRank() != 1)
4086 return emitOpError(
"expected 1-d vector for operand #1");
4090 if (vRHS.getRank() != 1)
4091 return emitOpError(
"expected 1-d vector for operand #2");
4092 if (vRES.getRank() != 2)
4094 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4095 return emitOpError(
"expected #1 operand dim to match result dim #1");
4096 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4097 return emitOpError(
"expected #2 operand dim to match result dim #2");
4098 if (vLHS.isScalable() && !vRHS.isScalable()) {
4102 "expected either both or only #2 operand dim to be scalable");
4106 if (vRES.getRank() != 1)
4108 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4109 return emitOpError(
"expected #1 operand dim to match result dim #1");
4112 if (vACC && vACC != vRES)
4113 return emitOpError(
"expected operand #3 of same type as result type");
4115 if (!getKindAttr()) {
4116 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
4117 "'vector.kind<add>')");
4122 return emitOpError(
"unsupported outerproduct type");
4131Type OuterProductOp::getExpectedMaskType() {
4132 auto vecType = this->getResultVectorType();
4133 return VectorType::get(vecType.getShape(),
4134 IntegerType::get(vecType.getContext(), 1),
4135 vecType.getScalableDims());
4149 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4151 shape.reserve(vectorType.getRank());
4153 for (
unsigned e = offsets.size(); idx < e; ++idx)
4154 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4155 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4156 shape.push_back(vectorType.getShape()[idx]);
4158 return VectorType::get(
shape, vectorType.getElementType(),
4159 vectorType.getScalableDims());
4162void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4163 Value source, ArrayRef<int64_t> offsets,
4164 ArrayRef<int64_t> sizes,
4165 ArrayRef<int64_t> strides) {
4166 result.addOperands(source);
4172 offsetsAttr, sizesAttr, stridesAttr));
4173 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(
result.name),
4175 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(
result.name),
4177 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(
result.name),
4181LogicalResult ExtractStridedSliceOp::verify() {
4182 auto type = getSourceVectorType();
4183 auto offsets = getOffsetsAttr();
4184 auto sizes = getSizesAttr();
4185 auto strides = getStridesAttr();
4186 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4188 "expected offsets, sizes and strides attributes of same size");
4190 auto shape = type.getShape();
4191 auto offName = getOffsetsAttrName();
4192 auto sizesName = getSizesAttrName();
4193 auto stridesName = getStridesAttrName();
4209 shape, offName, sizesName,
4214 offsets, sizes, strides);
4215 if (getResult().
getType() != resultType)
4216 return emitOpError(
"expected result type to be ") << resultType;
4218 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
4219 if (type.getScalableDims()[idx]) {
4220 auto inputDim = type.getShape()[idx];
4221 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4222 if (inputDim != inputSize)
4225 << (
" to match the corresponding base size from the input "
4227 << inputSize << (
" vs ") << inputDim << (
")");
4240 auto getElement = [](
ArrayAttr array,
int idx) {
4241 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4243 ArrayAttr extractOffsets = op.getOffsets();
4246 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4248 if (op.getSourceVectorType().getRank() !=
4249 insertOp.getSourceVectorType().getRank())
4251 ArrayAttr insertOffsets = insertOp.getOffsets();
4252 ArrayAttr insertStrides = insertOp.getStrides();
4255 if (extractOffsets.size() > insertOffsets.size())
4257 bool patialoverlap =
false;
4258 bool disjoint =
false;
4260 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4261 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
4263 int64_t start = getElement(insertOffsets, dim);
4264 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4265 int64_t offset = getElement(extractOffsets, dim);
4266 int64_t size = getElement(extractSizes, dim);
4268 if (start <= offset && offset < end) {
4271 if (offset + size > end)
4272 patialoverlap =
true;
4273 offsetDiffs.push_back(offset - start);
4280 if (!disjoint && !patialoverlap) {
4281 op.setOperand(insertOp.getValueToStore());
4284 op.setOffsetsAttr(
b.getI64ArrayAttr(offsetDiffs));
4290 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4305 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4310 if (op.hasNonUnitStrides())
4313 VectorType sourceVecTy = op.getSourceVectorType();
4317 VectorType sliceVecTy = op.getType();
4319 int64_t rank = sliceVecTy.getRank();
4331 const auto denseValuesBegin = dense.value_begin<
Attribute>();
4333 sliceValues.reserve(sliceVecTy.getNumElements());
4337 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4339 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4340 }
while (succeeded(
incSlicePosition(currSlicePosition, sliceShape, offsets)));
4342 assert(
static_cast<int64_t>(sliceValues.size()) ==
4343 sliceVecTy.getNumElements() &&
4344 "Invalid number of slice elements");
4348OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4349 if (getSourceVectorType() == getResult().
getType())
4356 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4363void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4385class StridedSliceFolder final
4386 :
public OpRewritePattern<ExtractStridedSliceOp> {
4388 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4390 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4391 PatternRewriter &rewriter)
const override {
4392 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4396 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4399 SmallVector<int64_t> firstOffsets =
getI64SubArray(firstOp.getOffsets());
4400 SmallVector<int64_t> firstSizes =
getI64SubArray(firstOp.getSizes());
4401 SmallVector<int64_t> secondOffsets =
getI64SubArray(secondOp.getOffsets());
4402 SmallVector<int64_t> secondSizes =
getI64SubArray(secondOp.getSizes());
4404 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4405 SmallVector<int64_t> combinedOffsets(newRank, 0);
4406 SmallVector<int64_t> combinedSizes(newRank);
4407 ArrayRef<int64_t> firstSourceShape =
4408 firstOp.getSourceVectorType().getShape();
4409 for (
unsigned i = 0; i < newRank; ++i) {
4410 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4411 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4412 combinedOffsets[i] = off1 + off2;
4414 if (i < secondSizes.size()) {
4415 combinedSizes[i] = secondSizes[i];
4416 }
else if (i < firstSizes.size()) {
4417 combinedSizes[i] = firstSizes[i];
4419 combinedSizes[i] = firstSourceShape[i];
4423 SmallVector<int64_t> combinedStrides(newRank, 1);
4425 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4443class StridedSliceCreateMaskFolder final
4444 :
public OpRewritePattern<ExtractStridedSliceOp> {
4448 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4449 PatternRewriter &rewriter)
const override {
4450 Location loc = extractStridedSliceOp.getLoc();
4454 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4458 if (extractStridedSliceOp.hasNonUnitStrides())
4461 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4463 SmallVector<int64_t> sliceOffsets;
4466 SmallVector<int64_t> sliceSizes;
4470 SmallVector<Value> sliceMaskDimSizes;
4471 sliceMaskDimSizes.reserve(maskDimSizes.size());
4475 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4476 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4480 IntegerAttr offsetAttr =
4482 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4483 Value sliceMaskDimSize =
4484 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4485 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4490 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4494 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4502class StridedSliceConstantMaskFolder final
4503 :
public OpRewritePattern<ExtractStridedSliceOp> {
4507 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4508 PatternRewriter &rewriter)
const override {
4511 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4512 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4513 if (!constantMaskOp)
4516 if (extractStridedSliceOp.hasNonUnitStrides())
4519 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4521 SmallVector<int64_t> sliceOffsets;
4524 SmallVector<int64_t> sliceSizes;
4528 SmallVector<int64_t> sliceMaskDimSizes;
4529 sliceMaskDimSizes.reserve(maskDimSizes.size());
4530 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4531 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4532 int64_t sliceMaskDimSize = std::max(
4533 static_cast<int64_t
>(0),
4534 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4535 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4538 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4539 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4540 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4543 if (llvm::is_contained(sliceMaskDimSizes, 0))
4544 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4549 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4557class StridedSliceBroadcast final
4558 :
public OpRewritePattern<ExtractStridedSliceOp> {
4562 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4563 PatternRewriter &rewriter)
const override {
4569 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4570 auto dstVecType = llvm::cast<VectorType>(op.getType());
4571 unsigned dstRank = dstVecType.getRank();
4572 unsigned rankDiff = dstRank - srcRank;
4576 bool needsSlice =
false;
4577 for (
unsigned i = 0; i < srcRank; i++) {
4578 if (srcVecType.getDimSize(i) != 1 &&
4579 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4586 SmallVector<int64_t> offsets =
4588 SmallVector<int64_t> sizes =
4590 for (
unsigned i = 0; i < srcRank; i++) {
4591 if (srcVecType.getDimSize(i) == 1) {
4599 source = ExtractStridedSliceOp::create(
4600 rewriter, op->getLoc(), source, offsets, sizes,
4609class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
4613 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4614 PatternRewriter &rewriter)
const override {
4616 Value splat = getScalarSplatSource(op.getSource());
4640class ContiguousExtractStridedSliceToExtract final
4641 :
public OpRewritePattern<ExtractStridedSliceOp> {
4645 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4646 PatternRewriter &rewriter)
const override {
4647 if (op.hasNonUnitStrides())
4649 Value source = op.getOperand();
4650 auto sourceType = cast<VectorType>(source.
getType());
4651 if (sourceType.isScalable() || sourceType.getRank() == 0)
4660 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4661 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4668 if (numOffsets == 0)
4673 if (numOffsets == sourceType.getRank() &&
4674 static_cast<int>(sizes.size()) == sourceType.getRank())
4678 for (
int i = 0; i < numOffsets; ++i) {
4686 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
4687 sizes[numOffsets] == 1) {
4692 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4693 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4702void ExtractStridedSliceOp::getCanonicalizationPatterns(
4703 RewritePatternSet &results, MLIRContext *context) {
4706 results.
add<StridedSliceFolder, StridedSliceCreateMaskFolder,
4707 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4708 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4717void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4718 VectorType vectorType, Value source,
4720 AffineMapAttr permutationMapAttr,
4723 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4725 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4726 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
4727 *padding, Value(), inBoundsAttr);
4731void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4732 VectorType vectorType, Value source,
4734 AffineMap permutationMap,
4735 std::optional<ArrayRef<bool>> inBounds) {
4736 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4737 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4740 SmallVector<bool>(vectorType.getRank(),
false));
4741 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4743 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4744 build(builder,
result, vectorType, source,
indices, *padding,
4745 permutationMapAttr, inBoundsAttr);
4749void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4750 VectorType vectorType, Value source,
4752 std::optional<ArrayRef<bool>> inBounds) {
4754 llvm::cast<ShapedType>(source.
getType()), vectorType);
4755 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4756 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4759 SmallVector<bool>(vectorType.getRank(),
false));
4760 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4762 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4763 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
4765 Value(), inBoundsAttr);
4768template <
typename EmitFun>
4772 for (
auto expr : permutationMap.
getResults()) {
4773 auto dim = dyn_cast<AffineDimExpr>(expr);
4774 auto zero = dyn_cast<AffineConstantExpr>(expr);
4776 if (zero.getValue() != 0) {
4778 "requires a projected permutation_map (at most one dim or the zero "
4779 "constant can appear in each result)");
4784 return emitOpError(
"requires a projected permutation_map (at most one "
4785 "dim or the zero constant can appear in each result)");
4787 if (seen[dim.getPosition()]) {
4789 "requires a permutation_map that is a permutation (found one dim "
4790 "used more than once)");
4792 seen[dim.getPosition()] =
true;
4799 VectorType vectorType, VectorType maskType,
4800 VectorType inferredMaskType,
AffineMap permutationMap,
4802 if (op->hasAttr(
"masked")) {
4803 return op->emitOpError(
"masked attribute has been removed. "
4804 "Use in_bounds instead.");
4807 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4808 return op->emitOpError(
4809 "requires source to be a memref or ranked tensor type");
4811 auto elementType = shapedType.getElementType();
4813 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4815 unsigned sourceVecSize =
4817 vectorElementType.getShape().back();
4818 unsigned resultVecSize =
4820 vectorType.getShape().back();
4821 if (resultVecSize % sourceVecSize != 0)
4822 return op->emitOpError(
4823 "requires the bitwidth of the minor 1-D vector to be an integral "
4824 "multiple of the bitwidth of the minor 1-D vector of the source");
4826 unsigned sourceVecEltRank = vectorElementType.getRank();
4827 unsigned resultVecRank = vectorType.getRank();
4828 if (sourceVecEltRank > resultVecRank)
4829 return op->emitOpError(
4830 "requires source vector element and vector result ranks to match.");
4831 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4834 return op->emitOpError(
"requires a permutation_map with result dims of "
4835 "the same rank as the vector type");
4838 return op->emitOpError(
"does not support masks with vector element type");
4841 unsigned minorSize =
4842 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4843 unsigned resultVecSize =
4846 return op->emitOpError(
4847 "requires the bitwidth of the minor 1-D vector to be an integral "
4848 "multiple of the bitwidth of the source element type");
4852 return op->emitOpError(
"requires a permutation_map with result dims of "
4853 "the same rank as the vector type");
4857 return op->emitOpError(
"requires permutation_map without symbols");
4859 if (permutationMap.
getNumInputs() != shapedType.getRank())
4860 return op->emitOpError(
"requires a permutation_map with input dims of the "
4861 "same rank as the source type");
4863 if (maskType && maskType != inferredMaskType)
4864 return op->emitOpError(
"inferred mask type (")
4865 << inferredMaskType <<
") and mask operand type (" << maskType
4869 return op->emitOpError(
"expects the in_bounds attr of same rank "
4870 "as permutation_map results: ")
4871 << AffineMapAttr::get(permutationMap)
4872 <<
" vs inBounds of size: " << inBounds.size();
4879 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4880 if (op.getPermutationMap().isMinorIdentity())
4881 elidedAttrs.push_back(op.getPermutationMapAttrName());
4883 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
4884 elidedAttrs.push_back(op.getInBoundsAttrName());
4888void TransferReadOp::print(OpAsmPrinter &p) {
4891 p <<
", " << getMask();
4898 auto i1Type = IntegerType::get(permMap.
getContext(), 1);
4900 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
4905 if (maskShape.empty())
4906 maskShape.push_back(1);
4911 return VectorType::get(maskShape, i1Type, scalableDims);
4928 if (hasMask.succeeded()) {
4935 if (types.size() != 2)
4936 return parser.
emitError(typesLoc,
"requires two types");
4938 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4939 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4940 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4941 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4943 return parser.
emitError(typesLoc,
"requires vector type");
4944 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(
result.name);
4948 if (shapedType.getRank() <
4951 "expected a custom permutation_map when "
4952 "rank(source) != rank(destination)");
4954 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4956 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4958 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(
result.name);
4959 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
4960 if (!inBoundsAttr) {
4961 result.addAttribute(inBoundsAttrName,
4970 if (hasMask.succeeded()) {
4971 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4973 maskInfo.
location,
"does not support masks with vector element type");
4976 "expected the same rank for the vector and the "
4977 "results of the permutation map");
4985 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4987 {1, static_cast<int32_t>(indexInfo.size()), 1,
4988 static_cast<int32_t>(hasMask.succeeded())}));
4992LogicalResult TransferReadOp::verify() {
4994 ShapedType shapedType = getShapedType();
4996 VectorType maskType = getMaskType();
4997 auto paddingType = getPadding().getType();
4998 auto permutationMap = getPermutationMap();
4999 VectorType inferredMaskType =
5002 auto sourceElementType = shapedType.getElementType();
5004 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
5005 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5008 shapedType, vectorType, maskType,
5009 inferredMaskType, permutationMap, getInBounds())))
5012 if (
auto sourceVectorElementType =
5013 llvm::dyn_cast<VectorType>(sourceElementType)) {
5016 if (sourceVectorElementType != paddingType)
5018 "requires source element type and padding type to match.");
5022 if (!VectorType::isValidElementType(paddingType))
5023 return emitOpError(
"requires valid padding vector elemental type");
5026 if (paddingType != sourceElementType)
5028 "requires formal padding and source of the same elemental type");
5039Type TransferReadOp::getExpectedMaskType() {
5046VectorType TransferReadOp::getVectorType() {
5047 return cast<VectorType>(getVector().
getType());
5050template <
typename TransferOp>
5054 if (op.getShapedType().isDynamicDim(indicesIdx))
5058 if (!cstOp.has_value())
5061 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5062 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5064 return cstOp.value() + vectorSize <= sourceSize;
5067template <
typename TransferOp>
5071 if (op.getTransferRank() == 0)
5076 newInBounds.reserve(op.getTransferRank());
5081 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
5083 if (op.isDimInBounds(i)) {
5084 newInBounds.push_back(
true);
5089 bool inBounds =
false;
5090 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
5093 dimExpr.getPosition());
5094 nonBcastDims.push_back(i);
5097 newInBounds.push_back(inBounds);
5105 bool allNonBcastDimsInBounds = llvm::all_of(
5106 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
5107 if (allNonBcastDimsInBounds) {
5110 newInBounds[idx] =
true;
5118 op.setInBoundsAttr(
b.getBoolArrayAttr(newInBounds));
5122template <
typename TransferOp>
5124 auto mask = op.getMask();
5131 op.getMaskMutable().clear();
5145static Value foldRAW(TransferReadOp readOp) {
5146 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5148 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5151 return defWrite.getVector();
5153 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5154 cast<VectorTransferOpInterface>(readOp.getOperation())))
5156 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5161OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5162 if (Value vec = foldRAW(*
this))
5173 return OpFoldResult();
5176std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5180void TransferReadOp::getEffects(
5181 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5183 if (llvm::isa<MemRefType>(getShapedType()))
5184 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5185 SideEffects::DefaultResource::get());
5189 if (hasPureTensorSemantics())
5196static AffineMap inverseWithUnusedDims(AffineMap map) {
5198 "expected a projected permutation map");
5203 int64_t pos = cast<AffineDimExpr>(
result).getPosition();
5233struct TransferReadAfterWriteToBroadcast
5234 :
public OpRewritePattern<TransferReadOp> {
5237 LogicalResult matchAndRewrite(TransferReadOp readOp,
5238 PatternRewriter &rewriter)
const override {
5239 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5243 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5247 if (readOp.getMask() || defWrite.getMask())
5250 if (readOp.getIndices() != defWrite.getIndices())
5253 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5257 if (readOp.getTransferChunkAccessed() !=
5258 defWrite.getTransferChunkAccessed())
5265 AffineMap readMap = readOp.getPermutationMap();
5266 AffineMap writeMap = defWrite.getPermutationMap();
5267 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5268 AffineMap composedMap = readMap.
compose(invWriteMap);
5282 int64_t numBroadcastedDims = broadcastedDims.size();
5283 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5285 for (
auto [idx, expr] : llvm::enumerate(composedMap.
getResults())) {
5286 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
5287 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5288 invPerm[effectiveDim] = idx;
5293 VectorType readVecTy = readOp.getVectorType();
5295 auto broadcastedVecTy =
5297 readVecTy.getElementType(),
5300 Value vec = defWrite.getVector();
5301 Location loc = readOp.getLoc();
5302 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5309void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5310 MLIRContext *context) {
5311 results.
add<TransferReadAfterWriteToBroadcast>(context);
5314FailureOr<std::optional<SmallVector<Value>>>
5315TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5316 if (!hasPureBufferSemantics())
5327void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5329 AffineMapAttr permutationMapAttr,
5332 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
5333 build(builder,
result, resultType, vector, dest,
indices, permutationMapAttr,
5334 mask, inBoundsAttr);
5338void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5340 AffineMapAttr permutationMapAttr,
5342 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5343 Value(), inBoundsAttr);
5348void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5350 AffineMap permutationMap,
5351 std::optional<ArrayRef<bool>> inBounds) {
5352 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5354 (inBounds && !inBounds.value().empty())
5357 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
5358 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5359 Value(), inBoundsAttr);
5364void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5366 std::optional<ArrayRef<bool>> inBounds) {
5367 auto vectorType = llvm::cast<VectorType>(vector.
getType());
5369 llvm::cast<ShapedType>(dest.
getType()), vectorType);
5370 build(builder,
result, vector, dest,
indices, permutationMap, inBounds);
5373ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5374 OperationState &
result) {
5377 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5378 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5379 SmallVector<Type, 2> types;
5380 OpAsmParser::UnresolvedOperand maskInfo;
5386 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
5391 if (types.size() != 2)
5392 return parser.
emitError(typesLoc,
"requires two types");
5394 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5396 return parser.
emitError(typesLoc,
"requires vector type");
5397 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5398 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5399 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5400 auto permMapAttrName =
5401 TransferWriteOp::getPermutationMapAttrName(
result.name);
5402 auto permMapAttr =
result.attributes.get(permMapAttrName);
5405 if (shapedType.getRank() <
5408 "expected a custom permutation_map when "
5409 "rank(source) != rank(destination)");
5411 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5413 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5415 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(
result.name);
5416 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5417 if (!inBoundsAttr) {
5418 result.addAttribute(inBoundsAttrName,
5426 if (hasMask.succeeded()) {
5427 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5429 maskInfo.
location,
"does not support masks with vector element type");
5432 "expected the same rank for the vector and the "
5433 "results of the permutation map");
5439 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5441 {1, 1, static_cast<int32_t>(indexInfo.size()),
5442 static_cast<int32_t>(hasMask.succeeded())}));
5443 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5447void TransferWriteOp::print(OpAsmPrinter &p) {
5450 p <<
", " << getMask();
5455LogicalResult TransferWriteOp::verify() {
5457 ShapedType shapedType = getShapedType();
5459 VectorType maskType = getMaskType();
5460 auto permutationMap = getPermutationMap();
5461 VectorType inferredMaskType =
5465 if (llvm::size(
getIndices()) != shapedType.getRank())
5466 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5470 if (hasBroadcastDim())
5471 return emitOpError(
"should not have broadcast dimensions");
5474 shapedType, vectorType, maskType,
5475 inferredMaskType, permutationMap, getInBounds())))
5488Type TransferWriteOp::getExpectedMaskType() {
5495Value TransferWriteOp::getVector() {
return getOperand(0); }
5496VectorType TransferWriteOp::getVectorType() {
5497 return cast<VectorType>(getValueToStore().
getType());
5520static LogicalResult foldReadInitWrite(TransferWriteOp write,
5521 ArrayRef<Attribute>,
5522 SmallVectorImpl<OpFoldResult> &results) {
5524 if (write.getTransferRank() == 0)
5526 auto rankedTensorType =
5527 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5529 if (!rankedTensorType)
5532 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5536 if (read.getTransferRank() == 0)
5539 if (!read.getPermutationMap().isMinorIdentity() ||
5540 !write.getPermutationMap().isMinorIdentity())
5543 if (read.getTransferRank() != write.getTransferRank())
5546 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5549 if (read.getBase().getType() != rankedTensorType)
5552 if (read.getVectorType() != write.getVectorType())
5555 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5558 auto isNotConstantZero = [](Value v) {
5560 return !cstOp.has_value() || cstOp.value() != 0;
5562 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5563 llvm::any_of(write.getIndices(), isNotConstantZero))
5566 results.push_back(read.getBase());
5570static bool checkSameValueWAR(vector::TransferReadOp read,
5571 vector::TransferWriteOp write) {
5572 return read.getBase() == write.getBase() &&
5573 read.getIndices() == write.getIndices() &&
5574 read.getPermutationMap() == write.getPermutationMap() &&
5575 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5592static LogicalResult foldWAR(TransferWriteOp write,
5593 SmallVectorImpl<OpFoldResult> &results) {
5594 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5596 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5600 if (!checkSameValueWAR(read, write))
5602 results.push_back(read.getBase());
5606LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5607 SmallVectorImpl<OpFoldResult> &results) {
5608 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
5610 if (succeeded(foldWAR(*
this, results)))
5622std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5626void TransferWriteOp::getEffects(
5627 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5629 if (llvm::isa<MemRefType>(getShapedType()))
5630 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5631 SideEffects::DefaultResource::get());
5635 if (hasPureTensorSemantics())
5665class FoldWaw final :
public OpRewritePattern<TransferWriteOp> {
5668 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5669 PatternRewriter &rewriter)
const override {
5670 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5672 vector::TransferWriteOp writeToModify = writeOp;
5674 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5678 writeToModify.getBaseMutable().assign(defWrite.getBase());
5683 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5684 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5688 if (!defWrite->hasOneUse())
5690 writeToModify = defWrite;
5691 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5720struct SwapExtractSliceOfTransferWrite
5721 :
public OpRewritePattern<tensor::InsertSliceOp> {
5725 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5726 PatternRewriter &rewriter)
const override {
5727 if (!insertOp.hasUnitStride())
5730 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5731 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5733 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5734 if (!transferOp || !transferOp->hasOneUse())
5739 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5741 "use-def chain is rank-reducing");
5745 if (!extractOp.hasZeroOffset()) {
5747 "ExtractSliceOp has non-zero offset");
5751 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5752 return getConstantIntValue(value) == static_cast<int64_t>(0);
5755 "TranferWriteOp has non-zero offset");
5759 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5761 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
5764 for (
auto [insertSize, extractSize] :
5765 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5768 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
5773 assert(transferOp.getVectorType().hasStaticShape() &&
5774 "expected vector to have a static shape");
5775 ArrayRef<int64_t>
vectorShape = transferOp.getVectorType().getShape();
5777 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5778 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
5780 insertOp,
"TransferWriteOp may not write the full tensor.");
5785 SmallVector<bool> newInBounds(
vectorShape.size(),
false);
5786 auto newExtractOp = tensor::ExtractSliceOp::create(
5787 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5788 insertOp.getDest(), insertOp.getMixedOffsets(),
5789 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5790 auto newTransferWriteOp = TransferWriteOp::create(
5791 rewriter, transferOp.getLoc(), transferOp.getVector(),
5792 newExtractOp.getResult(), transferOp.getIndices(),
5793 transferOp.getPermutationMapAttr(),
5796 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5804void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5805 MLIRContext *context) {
5806 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5809FailureOr<std::optional<SmallVector<Value>>>
5810TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5811 if (!hasPureBufferSemantics())
5821static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5823 MemRefType memRefTy) {
5826 if (!vecTy.isScalable() &&
5827 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5830 if (!memRefTy.isLastDimUnitStride())
5831 return op->
emitOpError(
"most minor memref dim must have unit stride");
5835LogicalResult vector::LoadOp::verify() {
5839 if (
failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
5842 if (memRefTy.getRank() < resVecTy.getRank())
5844 "destination memref has lower rank than the result vector");
5847 Type memElemTy = memRefTy.getElementType();
5848 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5849 if (memVecTy != resVecTy)
5850 return emitOpError(
"base memref and result vector types should match");
5851 memElemTy = memVecTy.getElementType();
5854 if (resVecTy.getElementType() != memElemTy)
5855 return emitOpError(
"base and result element types should match");
5856 if (llvm::size(
getIndices()) != memRefTy.getRank())
5857 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5861OpFoldResult LoadOp::fold(FoldAdaptor) {
5864 return OpFoldResult();
5867std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5871FailureOr<std::optional<SmallVector<Value>>>
5872LoadOp::bubbleDownCasts(OpBuilder &builder) {
5881LogicalResult vector::StoreOp::verify() {
5885 if (
failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
5888 if (memRefTy.getRank() < valueVecTy.getRank())
5889 return emitOpError(
"source memref has lower rank than the vector to store");
5892 Type memElemTy = memRefTy.getElementType();
5893 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5894 if (memVecTy != valueVecTy)
5896 "base memref and valueToStore vector types should match");
5897 memElemTy = memVecTy.getElementType();
5900 if (valueVecTy.getElementType() != memElemTy)
5901 return emitOpError(
"base and valueToStore element type should match");
5902 if (llvm::size(
getIndices()) != memRefTy.getRank())
5903 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5907LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5908 SmallVectorImpl<OpFoldResult> &results) {
5912std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5916FailureOr<std::optional<SmallVector<Value>>>
5917StoreOp::bubbleDownCasts(OpBuilder &builder) {
5926LogicalResult MaskedLoadOp::verify() {
5927 VectorType maskVType = getMaskVectorType();
5928 VectorType passVType = getPassThruVectorType();
5932 if (resVType.getElementType() != memType.getElementType())
5933 return emitOpError(
"base and result element type should match");
5934 if (llvm::size(
getIndices()) != memType.getRank())
5935 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5936 if (resVType.getShape() != maskVType.getShape())
5937 return emitOpError(
"expected result shape to match mask shape");
5938 if (resVType != passVType)
5939 return emitOpError(
"expected pass_thru of same type as result type");
5944class MaskedLoadFolder final :
public OpRewritePattern<MaskedLoadOp> {
5947 LogicalResult matchAndRewrite(MaskedLoadOp
load,
5948 PatternRewriter &rewriter)
const override {
5960 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
5965void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5966 MLIRContext *context) {
5967 results.
add<MaskedLoadFolder>(context);
5970OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5973 return OpFoldResult();
5976FailureOr<std::optional<SmallVector<Value>>>
5977MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
5986LogicalResult MaskedStoreOp::verify() {
5987 VectorType maskVType = getMaskVectorType();
5991 if (valueVType.getElementType() != memType.getElementType())
5992 return emitOpError(
"base and valueToStore element type should match");
5993 if (llvm::size(
getIndices()) != memType.getRank())
5994 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5995 if (valueVType.getShape() != maskVType.getShape())
5996 return emitOpError(
"expected valueToStore shape to match mask shape");
6001class MaskedStoreFolder final :
public OpRewritePattern<MaskedStoreOp> {
6004 LogicalResult matchAndRewrite(MaskedStoreOp store,
6005 PatternRewriter &rewriter)
const override {
6009 store, store.getValueToStore(), store.getBase(), store.getIndices());
6017 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // (body elided in this excerpt)
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// Cheap check to determine whether the index vector is a constant, 1-D,
/// fixed-size, zero-based, contiguous sequence [0, 1, .., n-1] (or a
/// vector.step, which produces exactly that sequence).
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  if (indexVec.getDefiningOp<StepOp>())
    return success();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
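// Illustrative example (not from the original source): both of the following
// index vectors satisfy isZeroBasedContiguousSeq for vector<4xindex>:
//
//   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
//   %idx = vector.step : vector<4xindex>
//
// whereas dense<[1, 2, 3, 4]> or any scalable index vector does not.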
namespace {
/// Fold gathers with an all-false mask to their pass-thru value. Gathers
/// with an all-true mask are left unchanged: there is no unmasked gather
/// equivalent to rewrite them to.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

/// Fold gathers with a zero-based contiguous index vector [0, 1, .., n-1]
/// into a contiguous masked load.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return failure();

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getOffsets(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};
} // namespace
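// Illustrative example of the contiguous-gather fold (not from the original
// source; types abbreviated). With indices [0, 1, .., n-1] the gather reads
// a contiguous run and is equivalent to a masked load:
//
//   %g = vector.gather %base[%c0][%idx], %mask, %pass : memref<?xf32>, ...
//   ==> %g = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, ...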
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
  // (body elided in this excerpt)
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (valueVType.getElementType() != baseType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
/// Fold scatters with an all-false mask: on a memref base the scatter is a
/// no-op and is erased; on a tensor base it is replaced by the base value.
/// Scatters with an all-true mask have no unmasked equivalent.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    ShapedType baseType = scatter.getBaseType();
    bool isMemRef = isa<MemRefType>(baseType);
    if (!isMemRef && !isa<RankedTensorType>(baseType))
      return failure();

    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      if (isMemRef)
        rewriter.eraseOp(scatter);
      else
        rewriter.replaceOp(scatter, scatter.getBase());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

/// Fold scatters with a zero-based contiguous index vector [0, 1, .., n-1]
/// into a contiguous masked store.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return failure();

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
    return success();
  }
};
} // namespace
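// Illustrative counterpart for scatters (not from the original source):
//
//   vector.scatter %base[%c0][%idx], %mask, %v : memref<?xf32>, ...
//   ==> vector.maskedstore %base[%c0], %mask, %v : memref<?xf32>, ...
//
// where %idx is dense<[0, 1, .., n-1]> or a vector.step.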
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  // (body elided in this excerpt)
}

//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//

LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
/// Fold expanding loads with an all-true or all-false mask, analogously to
/// MaskedLoadFolder: all-true becomes a plain vector.load (expansion with a
/// full mask is contiguous), all-false yields the pass-thru value.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // (body elided in this excerpt)
}

//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//

LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
/// Fold compressing stores with an all-true or all-false mask, analogously
/// to MaskedStoreFolder: all-true becomes a plain vector.store, all-false
/// erases the op.
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // (body elided in this excerpt)
}
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//

void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

LogicalResult ShapeCastOp::verify() {
  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that the element type is preserved.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitOpError("has different source and result element types");

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}

/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By 'order preserving' we mean that the flattened versions of the input
/// and output vectors are (numerically) identical; in other words, the
/// transpose is effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current)
        return false;
      current = p;
    }
  }
  return true;
}

OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x), if the transpose is order
  // preserving.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return OpFoldResult();
  }

  // shape_cast(broadcast(x)) -> x, if x and the result have the same type.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> constant, reshaped.
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  // shape_cast(poison) -> poison.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());

  return OpFoldResult();
}
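// Illustrative folding chains (not from the original source; types
// abbreviated):
//
//   %1 = vector.shape_cast %0 : vector<2x3xf32> to vector<6xf32>
//   %2 = vector.shape_cast %1 : vector<6xf32> to vector<3x2xf32>
//   ==> %2 = vector.shape_cast %0 : vector<2x3xf32> to vector<3x2xf32>
//
//   %t = vector.transpose %0, [1, 0, 2] : vector<1x4x2xf32> to vector<4x1x2xf32>
//   %c = vector.shape_cast %t : vector<4x1x2xf32> to vector<8xf32>
//   ==> %c = vector.shape_cast %0 : vector<1x4x2xf32> to vector<8xf32>
//       (the transpose only moves a unit dim, so it is order preserving)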
/// Returns a VectorType with the trailing non-scalable unit dims trimmed.
/// If all dims are unit dims, the innermost dim is kept so the result is
/// still a valid (non-0-D) vector type.
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
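// Worked examples for trimTrailingOneDims (not from the original source):
//
//   vector<4x1x1xi1> -> vector<4xi1>
//   vector<1x1xi1>   -> vector<1xi1>      (innermost dim kept)
//   vector<4x[1]xi1> -> vector<4x[1]xi1>  (scalable unit dim not trimmed)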
namespace {
/// Fold a shape cast of a create_mask or constant_mask that only drops
/// trailing unit dims of the mask: the corresponding trailing mask sizes
/// must be the constant 1, so dropping them does not change which lanes
/// are set.
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    // No unit dims to drop.
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check that the dropped mask operands are all the constant 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check that the dropped mask dim sizes are all 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
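// Illustrative before/after for the pattern above (not from the original
// source):
//
//   %m = vector.create_mask %dim, %c1, %c1 : vector<8x1x1xi1>
//   %r = vector.shape_cast %m : vector<8x1x1xi1> to vector<8xi1>
//   ==> %r = vector.create_mask %dim : vector<8xi1>
//
// If a trailing operand is not the constant 1, dropping it would change
// which lanes are set, so the pattern bails out.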
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
///   i) Y = ShapeCast(X), or
///  ii) Y = Broadcast(X).
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;

    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X) when the
    // element counts match (only possible when X is itself a vector).
    if (srcVectorType) {
      if (srcVectorType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    // Otherwise, replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
    // when X (scalar or vector) is directly broadcastable to the result.
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    if (srcIsScalar ||
        vector::isBroadcastableTo(broadcastOp.getSourceType(),
                                  dstVectorType) ==
            BroadcastableToResult::Success) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, dstVectorType, broadcastOp.getSource());
      return success();
    }
    return failure();
  }
};
} // namespace

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results
      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
          context);
}

//===----------------------------------------------------------------------===//
// VectorBitCastOp
//===----------------------------------------------------------------------===//

LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return OpFoldResult();

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32: two f16 lanes pack into one f32 lane, so
      // the 16-bit pattern is duplicated into the high and low halves.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(),
                                      FloatAttr::get(dstElemType, floatBits));
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        unsigned dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting a splat to a larger integer bit width: zero-extend the
        // splat value, then duplicate the lower-width element across it.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(),
                                        IntegerAttr::get(dstElemType, intBits));
        }
      }
    }
  }

  return OpFoldResult();
}
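// Worked example for the splat-widening fold above (not from the original
// source): bitcasting a splat vector<8xi8> of value 0x2A to vector<2xi32>
// zero-extends 0x2A to 32 bits, then repeats it across the wider lane:
//
//   0x0000002A -> 0x00002A2A -> 0x002A2A2A -> 0x2A2A2A2A
//
// i.e. dstBitWidth / srcBitWidth - 1 = 3 shift-or steps produce the i32
// splat 0x2A2A2A2A.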
//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}

LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}

OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity transposes, and more generally any transpose that
  // preserves the order of the elements.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return OpFoldResult();
}
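// Illustrative fold (not from the original source): a transpose that only
// permutes unit dims is order preserving and folds away when the source and
// result types already match, e.g.
//
//   %t = vector.transpose %v, [0, 2, 1]
//        : vector<2x1x1xf32> to vector<2x1x1xf32>
//   ==> %t folds to %v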
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify the transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

namespace {

/// Rewrites two back-to-back TransposeOp operations into a single one.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};

/// Replace transpose(splat-like(x)) with broadcast(x).
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(transposeOp.getVector());
    if (!splat)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transposeOp, transposeOp.getResultVectorType(), splat);
    return success();
  }
};

/// Folds transpose(create_mask)/transpose(constant_mask) into a new
/// create_mask/constant_mask with the mask sizes permuted.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};

/// Folds transpose(shape_cast) into a new shape_cast when the transpose is
/// order preserving (i.e. it only moves unit dims around).
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    // Merging the transpose into the shape_cast is always a valid
    // shape_cast, because the transpose just inserts/removes unit dims.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};

/// Folds transpose(from_elements(...)) into a new from_elements with the
/// element order permuted, e.g.:
///
///   %v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>
///   %t = vector.transpose %v, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
/// becomes
///   %t = vector.from_elements %e0, %e2, %e1, %e3 : vector<2x2xf32>
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElementsOp =
        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
    if (!fromElementsOp)
      return failure();

    VectorType srcTy = fromElementsOp.getDest().getType();
    VectorType dstTy = transposeOp.getType();

    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    int64_t rank = srcTy.getRank();

    // Invert the permutation: result dim i takes source dim permutation[i],
    // so source dim j maps to result dim inversePerm[j].
    SmallVector<int64_t> inversePerm(rank, 0);
    for (int64_t i = 0; i < rank; ++i)
      inversePerm[permutation[i]] = i;

    ArrayRef<int64_t> srcShape = srcTy.getShape();
    ArrayRef<int64_t> dstShape = dstTy.getShape();
    SmallVector<int64_t> srcIdx(rank, 0);
    SmallVector<int64_t> dstIdx(rank, 0);
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);

    auto elementsOld = fromElementsOp.getElements();
    SmallVector<Value> elementsNew;
    int64_t dstNumElements = dstTy.getNumElements();
    elementsNew.reserve(dstNumElements);

    // For each element of the result (in row-major order), pick the
    // corresponding element of the source.
    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
      dstIdx = delinearize(linearIdx, dstStrides);
      for (int64_t j = 0; j < rank; ++j)
        srcIdx[j] = dstIdx[inversePerm[j]];
      int64_t srcLin = linearize(srcIdx, srcStrides);
      elementsNew.push_back(elementsOld[srcLin]);
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(transposeOp, dstTy,
                                                        elementsNew);
    return success();
  }
};
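// Worked index computation for the pattern above (not from the original
// source): for srcTy = vector<2x3xf32> and permutation [1, 0], the result
// is vector<3x2xf32> and inversePerm = [1, 0]. Result element (i, j) reads
// source element (j, i); in linearized form, dst position i * 2 + j reads
// src position j * 3 + i, e.g. dst 3 (= (1, 1)) reads src 4 (= source
// entry (1, 1)), which is exactly the transposed matrix entry.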
/// Folds transpose(broadcast(x)) to broadcast(x) when the permutation never
/// moves a broadcast dimension past a group of non-broadcast dimensions,
/// i.e. the permutation is local to each group of broadcast dims.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        // Fail if any permutation destination for indices in [low, high)
        // leaves [low, high): the permutation is not local to the group.
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    // The preceding check guarantees the input is directly broadcastable to
    // the transpose output; assert that invariant.
    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeFromElements,
              FoldTransposeBroadcast>(context);
}

//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that the array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all are (the mask region
  // is a conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}

bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}

OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> bounds = getMaskDimSizes();
  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();

  auto createBoolSplat = [&](bool x) {
    return SplatElementsAttr::get(getVectorType(),
                                  BoolAttr::get(getContext(), x));
  };

  // Check the corner case of 0-D vectors first.
  if (vectorSizes.empty()) {
    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
    return createBoolSplat(bounds[0] == 1);
  }
  // Fold if the mask is all-true or all-false.
  if (bounds == vectorSizes)
    return createBoolSplat(true);
  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
    return createBoolSplat(false);
  return OpFoldResult();
}
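// Illustrative folds (not from the original source):
//
//   vector.constant_mask [2, 3] : vector<2x3xi1>  ==> dense<true>
//   vector.constant_mask [0, 0] : vector<2x3xi1>  ==> dense<false>
//   vector.constant_mask [1, 3] : vector<2x3xi1>  ==> not folded (mixed)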
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//

void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}

namespace {
/// Pattern to rewrite a CreateMaskOp as a ConstantMaskOp when all mask
/// operands are constants (or, for scalable dimensions, all-true constant
/// vscale multiples).
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank-zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the constant dim sizes.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value: for a scalable dim only zero (all-false) is
        // representable by a constant mask.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple (e.g. 4 x vscale): must cover the whole
        // scalable dim (all-true) to be representable.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp the values to the constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dim sizes is zero, set all of them to zero (the mask is
    // a conjunction of its dimension intervals).
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace createMaskOp with a ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
} // namespace
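// Illustrative fold (not from the original source):
//
//   %c2 = arith.constant 2 : index
//   %c3 = arith.constant 3 : index
//   %m = vector.create_mask %c2, %c3 : vector<4x4xi1>
//   ==> %m = vector.constant_mask [2, 3] : vector<4x4xi1>
//
// Out-of-bounds constants are clamped first, e.g. a constant 9 on a dim of
// size 4 behaves as 4 ("all set" in that dimension).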
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

//===----------------------------------------------------------------------===//
// MaskOp
//===----------------------------------------------------------------------===//

void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}

ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse the mask operand.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();

  // Optional passthru operand.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse the op region.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}))
    return failure();

  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");

    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }

  return success();
}

void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print the single masked operation and skip the terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}

void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.getBlocks().empty() || region.getBlocks().front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.front();
  if (isa<vector::YieldOp>(block.back()))
    return;

  // 3. For a non-empty `vector.mask` without an explicit terminator:

  // Create a default terminator if the number of masked operations is not
  // one. This case will trigger a verification failure.
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // Create a terminator that yields the results of the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}

LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask: nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Folds an empty `vector.mask` with no passthru operand.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();

  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  if (terminator.getNumOperands() == 0) {
    // `vector.mask` has no results, just remove the `vector.mask`.
    return success();
  }

  // `vector.mask` has results: propagate the yielded values as fold results.
  llvm::append_range(results, terminator.getOperands());
  return success();
}

LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();

  // An all-true mask makes the masking redundant: move the maskable
  // operation outside of the `vector.mask` region and forward its results.
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}

namespace {
/// Canonicalize empty `vector.mask` operations that can't be handled in
/// `MaskOp::fold` because they require creating new operations: when a
/// passthru is present, the yielded value must be selected against it.
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    if (!maskOp.isEmpty())
      return failure();

    if (!maskOp.hasPassthru())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
        terminator.getOperand(0), maskOp.getPassthru());
    return success();
  }
};
} // namespace

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}

/// Returns the masked operation within the `vector.mask` region, or nullptr
/// if the region only holds the terminator.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

/// Returns true if the mask operation has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check that reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check the shapes of the initial value and the source.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify the supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}

Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
//===----------------------------------------------------------------------===//
// StepOp
//===----------------------------------------------------------------------===//

void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRanges) {
  auto resultType = cast<VectorType>(getType());
  if (resultType.isScalable())
    return;
  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
  APInt zero(bitwidth, 0);
  APInt high(bitwidth, resultType.getDimSize(0) - 1);
  ConstantIntRanges result = {zero, high, zero, high};
  setResultRanges(getResult(), result);
}

namespace {
/// Fold an `arith.cmpi` that compares a `vector.step` against a constant for
/// which the comparison has the same outcome in every lane, replacing the
/// comparison with a constant i1 splat. The step values are known to be
/// [0, 1, .., stepSize - 1].
struct StepCompareFolder : public OpRewritePattern<StepOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    const int64_t stepSize = stepOp.getResult().getType().getNumElements();

    for (OpOperand &use : stepOp.getResult().getUses()) {
      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
      if (!cmpiOp)
        continue;

      // arith.cmpi canonicalization moves constants to the rhs, so only
      // handle the case where the step value is the lhs (operand 0).
      const unsigned stepOperandNumber = use.getOperandNumber();
      if (stepOperandNumber != 0)
        continue;

      // The other operand must be a constant.
      unsigned constOperandNumber = 1;
      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
      std::optional<int64_t> maybeConstValue =
          getConstantIntValue(otherOperand);
      if (!maybeConstValue.has_value())
        continue;

      int64_t constValue = maybeConstValue.value();
      arith::CmpIPredicate pred = cmpiOp.getPredicate();

      auto maybeSplat = [&]() -> std::optional<bool> {
        // step ult C is all-true (step uge C all-false) if C >= stepSize.
        if ((pred == arith::CmpIPredicate::ult ||
             pred == arith::CmpIPredicate::uge) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ult;

        // step ule C is all-true (step ugt C all-false) if C >= stepSize - 1.
        if ((pred == arith::CmpIPredicate::ule ||
             pred == arith::CmpIPredicate::ugt) &&
            stepSize - 1 <= constValue)
          return pred == arith::CmpIPredicate::ule;

        // step eq C is all-false (step ne C all-true) if C >= stepSize.
        if ((pred == arith::CmpIPredicate::eq ||
             pred == arith::CmpIPredicate::ne) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ne;

        return std::nullopt;
      }();

      if (!maybeSplat.has_value())
        continue;

      rewriter.setInsertionPointAfter(cmpiOp);

      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
      if (!type)
        continue;

      Value splat = mlir::arith::ConstantOp::create(
          rewriter, cmpiOp.getLoc(), type,
          DenseElementsAttr::get(type, maybeSplat.value()));
      rewriter.replaceOp(cmpiOp, splat);
      return success();
    }

    return failure();
  }
};
} // namespace
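// Worked example (not from the original source): %s = vector.step :
// vector<4xindex> has lanes [0, 1, 2, 3], so
//
//   %c = arith.cmpi ult, %s, %splat4 : vector<4xindex>
//
// is true in every lane (constValue 4 >= stepSize 4) and folds to
//
//   %c = arith.constant dense<true> : vector<4xi1>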
void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<StepCompareFolder>(context);
}

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//

/// Create the vector.yield-terminated region of a vector.mask op with
/// `maskableOp` as the masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Move the masked operation into the region block and yield its results.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}

/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid; otherwise, returns
/// the maskable operation itself.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return MaskOp::create(builder, maskableOp->getLoc(),
                          maskableOp->getResultTypes(), mask, passthru,
                          maskableOp, createMaskOpRegion);
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, maskableOp,
                        createMaskOpRegion);
}

/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. Returns
/// `newValue` if no mask is provided.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return arith::SelectOp::create(builder, newValue.getLoc(),
                                 newValue.getType(), mask, newValue, passthru);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification; the result values of the two ops must be of the same types.
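A hedged sketch of the shorthand/longhand relationship; the free function rewriteToAddI is illustrative and assumes op, lhs, and rhs come from a surrounding match:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

LogicalResult rewriteToAddI(Operation *op, Value lhs, Value rhs,
                            PatternRewriter &rewriter) {
  // Shorthand: create arith.addi at the current insertion point, replace
  // `op`'s results with its results, and erase `op`.
  rewriter.replaceOpWithNewOp<arith::AddIOp>(op, lhs, rhs);
  // Longhand equivalent:
  //   auto add = arith::AddIOp::create(rewriter, op->getLoc(), lhs, rhs);
  //   rewriter.replaceOp(op, add.getResult());
  return success();
}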
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
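A hedged sketch of computeConstantDelta, in the spirit of how the transfer-disjointness helpers in this file use it; the header path is taken from ValueBoundsOpInterface and the helper name is hypothetical:

#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;

// True only when the two index values provably differ by a nonzero
// constant; failure means "unknown", not "equal".
bool provablyDistinct(Value a, Value b) {
  FailureOr<int64_t> delta =
      ValueBoundsConstraintSet::computeConstantDelta(a, b);
  return succeeded(delta) && *delta != 0;
}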
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble down, in place, a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
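A minimal sketch, assuming v and acc are scalars or vectors of the same type; the helper name addInto is hypothetical:

#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// Emits arith.addf for float element types, arith.addi for integers.
Value addInto(OpBuilder &b, Location loc, Value v, Value acc) {
  return vector::makeArithReduction(b, loc, vector::CombiningKind::ADD,
                                    /*v1=*/v, /*acc=*/acc);
}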
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for subscripts in the vector dialect.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
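A minimal sketch, assuming maskableOp implements the maskable-op interface and mask has the matching mask type; passthru is left unset and the helper name is hypothetical:

#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// Wraps the op as:  %r = vector.mask %mask { <maskableOp> } : ...
Operation *wrapInMask(OpBuilder &b, Operation *maskableOp, Value mask) {
  return vector::maskOperation(b, maskableOp, mask);
}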
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g. 4 * vector.vscale), return the multiplier.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
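A minimal sketch; for memref<?x?x?xf32> and vector<4x8xf32> the resulting map is (d0, d1, d2) -> (d1, d2), i.e. the vector maps onto the two minor dimensions of the source:

#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

AffineMap defaultPermutationMap(ShapedType source, VectorType vec) {
  return vector::getTransferMinorIdentityMap(source, vec);
}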
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector lane based on mask.
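A minimal sketch; per the description, unmasked lanes of the result take the corresponding passthru lane (the helper name is hypothetical):

#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// mask ? newValue : passthru, lane by lane.
Value blendWithPassthru(OpBuilder &b, Value mask, Value newValue,
                        Value passthru) {
  return vector::selectPassthru(b, mask, newValue, passthru);
}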
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the accessed underlying tensor/memref to be the same.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operations to access the same underlying tensor/memref.
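A hedged sketch mirroring the store-to-load forwarding use case: a read may be reordered past a prior write to the same memref only if the accesses provably do not overlap (the helper name is hypothetical):

#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

bool mayForwardPast(vector::TransferWriteOp write,
                    vector::TransferReadOp read) {
  // Both ops must access the same underlying tensor/memref.
  return vector::isDisjointTransferSet(
      cast<VectorTransferOpInterface>(write.getOperation()),
      cast<VectorTransferOpInterface>(read.getOperation()),
      /*testDynamicValueUsingBounds=*/true);
}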
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-terminated region of a vector.mask op, with maskableOp as the masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated with the given arith::AtomicRMWKind.
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector.broadcast op.
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
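A minimal sketch: an OpFoldResult is either an Attribute or a Value, and getConstantIntValue handles both, also looking through defining constant ops (the helper name isZeroOfr is hypothetical):

#include "mlir/Dialect/Utils/StaticValueUtils.h"

using namespace mlir;

bool isZeroOfr(OpFoldResult ofr) {
  std::optional<int64_t> cst = getConstantIntValue(ofr);
  return cst.has_value() && *cst == 0;
}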
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product (row-major strides).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
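A minimal sketch: for sizes {2, 3, 4}, computeStrides returns the suffix product {12, 4, 1}, and delinearize maps linear index 17 back to the offsets {1, 1, 1}, since 17 = 1*12 + 1*4 + 1*1 (the helper name is hypothetical):

#include "mlir/Dialect/Utils/IndexingUtils.h"

using namespace mlir;

SmallVector<int64_t> offsetsOf(int64_t linearIndex, ArrayRef<int64_t> sizes) {
  SmallVector<int64_t> strides = computeStrides(sizes);
  return delinearize(linearIndex, strides);
}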
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult above for a single OpFoldResult.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
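A minimal sketch: with map (d0, d1, d2) -> (d2, d0) and source {a, b, c}, applyPermutationMap returns {c, a}; each result of the (projected permutation) map selects the source element at that dimension position (the wrapper name is hypothetical):

#include "mlir/IR/AffineMap.h"

using namespace mlir;

SmallVector<int64_t> project(AffineMap map, ArrayRef<int64_t> source) {
  return applyPermutationMap(map, source);
}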
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
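A minimal sketch, the inverse of the delinearize example above: with basis {12, 4, 1}, offsets {1, 1, 1} map back to 1*12 + 1*4 + 1*1 = 17 (the wrapper name is hypothetical):

#include "mlir/Dialect/Utils/IndexingUtils.h"

using namespace mlir;

int64_t linearIndexOf(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  return linearize(offsets, basis);
}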
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using classes in the detail namespace.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Return a fused vector::ContractionOp that represents patterns such as a vector.contract whose result feeds an elementwise add: the added operand is folded into the contraction's accumulator.
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
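A minimal end-to-end sketch of the OpRewritePattern machinery named above; the pattern itself is illustrative, not one from this file:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct FoldTrivialBroadcast : OpRewritePattern<vector::BroadcastOp> {
  using Base::Base; // inherit the OpRewritePattern constructors

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // A broadcast whose source already has the result type is a no-op.
    if (op.getSource().getType() != op.getVector().getType())
      return rewriter.notifyMatchFailure(op, "not a same-type broadcast");
    rewriter.replaceOp(op, op.getSource());
    return success();
  }
};

// Registration, e.g. inside a populate*Patterns function:
//   patterns.add<FoldTrivialBroadcast>(patterns.getContext());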
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)