#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
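// NOTE: The fragments below are from the mask-format helper, which classifies
// a mask value as all-true, all-false, or unknown by inspecting its defining
// op (arith.constant, vector.constant_mask, or vector.create_mask).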
  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
    int64_t val = 0;
    for (bool b : denseElts.getValues<bool>())
      if (b && val >= 0)
        val++;
      else if (!b && val <= 0)
        val--;
    auto shape = m.getType().getShape();
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t val =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  vector::YieldOp::create(builder, loc);
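// A combining kind is only supported for element types it is meaningful on:
// ADD/MUL accept integer, index, and float types; the bitwise kinds and
// integer min/max accept integer/index types; the float min/max variants
// require a float element type.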
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
                                         VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
                                         VectorType vectorType) {
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
                                       vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
                               vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
                               vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
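// Disjointness of two transfers with the same vector type: for a leading
// (non-vector) dimension it suffices to prove the indices differ; for a
// vector dimension the indices must be at least `vectorDim` apart. Constant
// indices are compared directly; otherwise, when testDynamicValueUsingBounds
// is set, value-bounds analysis is consulted.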
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;
        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
      continue;
    }

    int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
    if (cstIndexA.has_value() && cstIndexB.has_value()) {
      int64_t distance = std::abs(*cstIndexA - *cstIndexB);
      if (distance >= vectorDim)
        return true;
      continue;
    }
    if (testDynamicValueUsingBounds) {
      FailureOr<int64_t> delta =
          affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
      if (succeeded(delta) && std::abs(*delta) >= vectorDim)
        return true;
      FailureOr<int64_t> computeDelta =
          ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
      if (succeeded(computeDelta)) {
        if (std::abs(computeDelta.value()) >= vectorDim)
          return true;
                                   VectorTransferOpInterface transferB,
                                   bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();
    // Carry the overflow into the next (more significant) dimension.
    posInDim = offsetInDim;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return arith::ConstantIndexOp::create(
                                 builder, loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();
                    return cast<Value>(foldResult);
                  });
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto intType = dyn_cast<IntegerType>(expectedType)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
  }

  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    auto intType = dyn_cast<IntegerType>(expectedType);

    APFloat floatVal = floatAttr.getValue();
    APInt intVal = floatVal.bitcastToAPInt();
    return IntegerAttr::get(expectedType, intVal);
  }
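// Dialect initialization: registers the TableGen-generated attributes and
// operations, the inliner interface, and interfaces that are promised here
// but implemented lazily by other libraries (bufferization, subset ops, and
// the LLVM conversion).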
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}
  if (isa<ub::PoisonAttrInterface>(value))
    return value.getDialect().materializeConstant(builder, value, type, loc);

  return arith::ConstantOp::materialize(builder, value, type, loc);
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // A 1-D reduction over a non-reduced dim is a no-op.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }

  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError("destination type ")
           << getType() << " is incompatible with source type "
           << getSourceVectorType();

  return success();
}
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}
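// Pattern: if every reduced dimension of a multi_dim_reduction has size 1,
// nothing is actually reduced; the op can be replaced by a shape_cast (or an
// extract, when all dimensions are dropped) combined with the accumulator,
// shape-casting/extracting any mask alongside.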
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      }
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      if (mask)
        mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
                         arith::FastMathFlags fastMathFlags) {

                         arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind,
        vector, acc, fastMathFlags);
}
LogicalResult ReductionOp::verify() {
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}
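// Maps an arith.atomicrmw kind to the equivalent vector.reduction combining
// kind over the given vector value.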
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::OR, vector);
  case arith::AtomicRMWKind::minnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINNUMF, vector);
  case arith::AtomicRMWKind::maxnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXNUMF, vector);
  case arith::AtomicRMWKind::xori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::XOR, vector);
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }
    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();
    Location loc = reductionOp.getLoc();
    if (mask)
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
  results.add<ElideSingleElementReduction>(context);
                      getIndexingMapsAttrName(result.name),

                      getIteratorTypesAttrName(result.name),

        return IteratorTypeAttr::get(builder.getContext(), t);

        ContractionOp::getDefaultKind());

                          ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
  DictionaryAttr dictAttr;

  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        ArrayAttr::get(parser.getContext(), iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);

  std::array<VectorType, 2> maskTypes = {
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) !=
            rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
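// verifyOutputShape: the expected result dimensions are the LHS free and
// batch dimensions followed by the RHS free dimensions. An empty list implies
// a scalar accumulator/result; otherwise both must be vectors of exactly the
// shape implied by the indexing maps and the operand shapes.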
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] =
              getAffineConstantExpr(v.getShape()[idx], op.getContext());
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];

    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify that at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify that the contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
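// The expected mask type of a masked contraction covers the whole iteration
// space: one i1 per iteration-space point, with each dimension's size (and
// scalable flag) recovered from the LHS/RHS operands through their indexing
// maps.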
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), 1),
                         maskShapeScalableDims);
          getIteratorTypesAttrName(), getKindAttrName()};
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get the reduction dim size from the lhs shape (same size on the rhs).
      int64_t lhsDimIndex =
          getResultIndex(getIndexingMapsArray()[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get the parallel dimension size from the result shape.
    int64_t resDimIndex = getResultIndex(getIndexingMapsArray()[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  auto indexingMaps = getIndexingMapsArray();
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  auto indexingMaps = getIndexingMapsArray();
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
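// Pattern: fold `add(contract(a, b, zero), other)` into `contract(a, b,
// other)` by cloning the contraction with `other` as the accumulator,
// provided the current accumulator is the zero attribute of its type.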
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
  setResultRanges(getResult(), argRanges.front());

  auto vectorTy = cast<VectorType>(source.getType());

  build(builder, result, source, dynamicPos,
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the result type");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
    return failure();

  // Dynamic positions are not handled yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getSource());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
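// ExtractFromInsertTransposeChainState walks an extract backwards through a
// chain of insert/transpose ops: a transpose permutes the tracked position;
// an insert at exactly the extracted position (or at a prefix of it) yields
// the inserted value; a provably disjoint insert lets the walk continue into
// the insert's destination.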
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Checks whether `a` is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Checks whether `a` and `b` agree at every position where both entries
  /// are non-negative (negative entries are sentinels and always match).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<InsertOp>();
    nextTransposeOp = v.getDefiningOp<TransposeOp>();
  }

  LogicalResult handleTransposeOp();
  LogicalResult handleInsertOpWithMatchingPos(Value &res);
  LogicalResult handleInsertOpWithPrefixPos(Value &res);
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;
  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
  llvm::append_range(extractPosition, sentinels);
}
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
  return success();
}
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
    return failure();

  res = nextInsertOp.getValueToStore();
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();

  res = nextInsertOp.getValueToStore();
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  if (extractOp.hasDynamicPosition())
    return Value();

  bool nothingToFold = (source == extractOp.getSource());
  if (nothingToFold || !canFold())
    return Value();

  // Fold by updating the op in place and returning its result.
  OpBuilder b(extractOp.getContext());
  extractOp.setStaticPosition(
      ArrayRef(extractPosition).take_front(extractedRank));
  extractOp.getSourceMutable().assign(source);
  return extractOp.getResult();
}
Value ExtractFromInsertTransposeChainState::fold() {
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getSource();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1: a transpose just composes the permutation; iterate.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the insert position matches the extracted position exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: the insert position is a prefix of the extracted position.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: the insert overlaps the extracted chunk; bail.
    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
    if (intersectsWhereNonNegative(insertedPos, extractPosition))
      return Value();

    // Case 5: disjoint; keep walking through the insert's destination.
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;
  };

  if (isa<BroadcastOp>(op))
    return true;

  auto shapeCast = dyn_cast<ShapeCastOp>(op);
  if (!shapeCast)
    return false;

  VectorType srcType = shapeCast.getSourceVectorType();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  uint64_t srcRank = srcType.getRank();
  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
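// Folding an extract of a broadcast-like op: when the extract produces the
// broadcast input's type, the broadcast is a no-op; otherwise the extract is
// redirected to the input, remapping positions across the rank difference
// and zeroing positions on stretched (size-1) input dims.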
  Operation *defOp = extractOp.getSource().getDefiningOp();
  if (!defOp || !isBroadcastLike(defOp))
    return Value();

  Value input = defOp->getOperand(0);
  if (extractOp.getType() == input.getType())
    return input;

  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;

  if (extractRank > inputRank)
    return Value();

  assert(inputType && "input must be a vector type because of previous checks");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  if (extractType &&
      extractType.getShape() != inputShape.take_back(extractRank))
    return Value();

  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;

  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
  }

  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from the lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }

  // Extract the strides associated with the extract position.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);

  // Compute the strides of the shape_cast source and locate the position
  // there.
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  SmallVector<int64_t, 4> newStrides;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());

  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions that are fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The result dimensions must be untouched by the extract_strided_slice.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the inserted chunk.
    if (!disjoint) {
      // Partially inserted inner dimensions mean a partial overlap; bail.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getSourceMutable().assign(insertOp.getValueToStore());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // Disjoint chunks: keep looking through the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
  if (extractOp.hasDynamicPosition())
    return {};

  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Flatten the position into a linear index.
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  ArrayRef<int64_t> vectorShape;
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();
  else
    vectorShape = op.getDestVectorType().getShape();

  // No dynamic operands: nothing to fold.
  if (!dynamicPosition.size())
    return {};

  unsigned index = 0;
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (ShapedType::isStatic(staticPosition[i]))
      continue;
    Attribute positionAttr = adaptor.getDynamicPosition()[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
      // Only fold in-bounds values (-1 is the poison marker).
      if (value >= -1 && value < vectorShape[i]) {
        staticPosition[i] = attr.getInt();
        opChange = true;
        continue;
      }
    }
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    return op.getResult();
  }
  return {};
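// A static kPoisonIndex (-1) anywhere in the position list makes the whole
// extract/insert poison, so it folds to a ub.poison attribute.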
  if (!is_contained(staticPos, poisonVal))
    return {};

  return ub::PoisonAttr::get(context);
  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
    return srcAttr;

  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
  if (!denseAttr)
    return {};

  if (denseAttr.isSplat()) {
    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    return newAttr;
  }

  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())
    return {};
  if (extractOp.hasDynamicPosition())
    return {};

  // Compute the linearized position of the contiguous chunk to extract.
  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
  copy(extractOp.getStaticPosition(), completePositions.begin());
  int64_t startPos =
      linearize(completePositions, computeStrides(vecTy.getShape()));
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  Attribute newAttr;
  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
    SmallVector<Attribute> elementValues(
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
  } else {
    newAttr = *denseValuesBegin;
  }
  return newAttr;
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // Fold an index-less extract whose types already agree.
  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
    return getSource();

  SmallVector<Value> operands = {getSource()};

  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;

  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;

  return inplaceFolded;
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    Operation *defOp = extractOp.getSource().getDefiningOp();
    if (!defOp || !isBroadcastLike(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
    if (!outType)
      return failure();
    if (isBroadcastableTo(source.getType(), outType) !=
        BroadcastableToResult::Success)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
    return success();
  }
};
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto createMaskOp =
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = false;
    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim are unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If the position is past the bound, the extracted mask is all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // Dynamic position in a partially-set dim: unknown region.
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
  }
};
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();

  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();

  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
  return success();
}
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  if (extractOp.hasDynamicPosition())
    return failure();

  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Linearize the position of the first extracted element.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
  if (operands.empty())
    return false;

  return llvm::all_of(operands, [&](Value operand) {
    Operation *currentDef = operand.getDefiningOp();
    return currentDef == defOp;
  });
  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();

  llvm::append_range(results, fromElementsOp.getElements());
  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
  if (!bcastOp)
    return failure();
  // Only handle broadcasts from a scalar source.
  if (isa<VectorType>(bcastOp.getSource().getType()))
    return failure();

  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
  Value scalar = bcastOp.getSource();
  results.assign(resultVecType.getNumElements(), scalar);
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {

ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                               ToElementsOp::Adaptor adaptor,
                               SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vecType = cast<VectorType>(adaptor.getSource().getType());
  Type elType = vecType.getElementType();
  inferredReturnTypes.append(vecType.getNumElements(), elType);
  return success();
}
    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
    if (!bcastOp)
      return failure();

    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
    if (!srcType)
      return failure();

    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
    ArrayRef<int64_t> srcShape = srcType.getShape();
    ArrayRef<int64_t> dstShape = dstType.getShape();
    int64_t dstRank = dstShape.size();
    int64_t srcRank = srcShape.size();

    auto srcElems = vector::ToElementsOp::create(
        rewriter, toElementsOp.getLoc(), bcastOp.getSource());

    int64_t dstCount = llvm::product_of(dstShape);
    SmallVector<Value> replacements;
    replacements.reserve(dstCount);

    SmallVector<int64_t> srcIdx(srcRank);
    for (int64_t lin = 0; lin < dstCount; ++lin) {
      auto dstIdx = delinearize(lin, computeStrides(dstShape));
      // Map the destination index onto the broadcast source, treating
      // size-1 source dims as stretched.
      for (int64_t k = 0; k < srcRank; ++k)
        srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
      int64_t srcLin = linearize(srcIdx, computeStrides(srcShape));
      replacements.push_back(srcElems.getResult(srcLin));
    }

    rewriter.replaceOp(toElementsOp, replacements);
    return success();
void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ToElementsOfBroadcast>(context);
}
  OperandRange fromElemsOperands = fromElementsOp.getElements();
  if (fromElemsOperands.empty())
    return {};

  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
  if (!toElementsOp)
    return {};

  Value toElementsInput = toElementsOp.getSource();
  if (fromElementsOp.getType() == toElementsInput.getType() &&
      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
    return toElementsInput;
  }
  if (llvm::any_of(elements, [](Attribute attr) {
        return !attr || isa<ub::PoisonAttrInterface>(attr);
      }))
    return {};

  auto destVecType = fromElementsOp.getDest().getType();
  auto destEltType = destVecType.getElementType();
  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
    return {};

  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {

  if (!llvm::all_equal(fromElementsOp.getElements()))
    return failure();
  rewriter.replaceOpWithNewOp<BroadcastOp>(
      fromElementsOp, fromElementsOp.getType(),
      fromElementsOp.getElements().front());
  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {
    // Single-element case is handled elsewhere.
    if (fromElements.getType().getNumElements() == 1)
      return failure();

    TypedValue<VectorType> source;
    SmallVector<int64_t> expectedPosition;
    SmallVector<int64_t> combinedPosition;

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {
      auto extractOp = element.getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");

      if (insertIndex == 0) {
        source = extractOp.getSource();
      } else if (extractOp.getSource() != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }

      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        int64_t numSuffixElms = 1;
        int64_t index = rank;
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);
          --index;
        }
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements do not form a suffix of source");
        }
        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);
      } else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      }
      increment(expectedPosition, source.getType().getShape());
    }

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        fromElements, fromElements.getType(), extracted);
    return success();
  }
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    indices[dim] += 1;
    if (indices[dim] < shape[dim])
      break;
    indices[dim] = 0;
  }
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
static llvm::SetVector<int64_t>
computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
                           ArrayRef<int64_t> dstShape) {
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  llvm::SetVector<int64_t> res;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");
      res.insert(dstDim);
    }
    ++dstDim;
  }
  return res;
}
llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (!srcVectorType)
    return {};
  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
                                      getResultVectorType().getShape());
}
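// createOrFoldBroadcastOp supports "dim-1" broadcasting, which
// vector.broadcast itself does not: it first broadcasts into a shape where
// the broadcast dims lead, then transposes the result into the requested
// destination shape when the permutation is non-trivial.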
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness: dstShape with the broadcast dims dropped must match the
  // source vector shape.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 1: a scalar source broadcasts directly.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 2: build the broadcast shape (broadcast dims first, then the source
  // shape) and the permutation that moves each dim into place.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      permutation[i] = nextSrcShapeDim++;
    }
  }
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  assert(!broadcastedDims.contains(dstShape.size() - 1) ||
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // Step 3: transpose if the permutation is not the identity.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  return res;
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // A scalar broadcasts to a vector of the same element type.
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // Each trailing source dim must match exactly or be a singleton; leading
  // destination dims are simply duplicated.
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine; otherwise reject mixing fixed-width and scalable
        // dims.
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;
        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
  if (!srcShapeCast)
    return failure();

  VectorType srcType = srcShapeCast.getSourceVectorType();
  VectorType destType = broadcastOp.getResultVectorType();

  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> shapecastShape =
      srcShapeCast.getResultVectorType().getShape();

  // The shape_cast can only be absorbed if it merely adds or removes leading
  // unit dims: the trailing dims must agree.
  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
  if (!llvm::equal(srcShape.take_back(numTrailingDims),
                   shapecastShape.take_back(numTrailingDims)))
    return failure();

  assert(all_of(srcShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         all_of(shapecastShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         "ill-formed shape_cast");

  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
  return success();
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());
  return {};
}
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
                                             broadcastOp.getResultVectorType(),
                                             srcBroadcast.getSource());
    return success();
  }
};

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder>(context);
}
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but the leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify the mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all mask indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // The leading dimension matches the mask length; all trailing dimensions
  // match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  if (v1Type.getRank() == 0)
    return {};

  auto mask = getMask();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // shuffle poison, poison -> poison.
  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only 1-D is supported to avoid n-D DenseElementsAttr manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // If one operand is poison, use the first element of the other operand as
  // the placeholder for poison indices.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    if (!v2DenseAttr)
      return {};
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    if (!v1DenseAttr)
      return {};
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }
    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
                                PatternRewriter &rewriter) const override {
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
      return failure();
    if (mask.size() != 1)
      return failure();
    VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
static Value getScalarSplatSource(Value value) {
  Operation *defOp = value.getDefiningOp();
  auto broadcast = dyn_cast_or_null<vector::BroadcastOp>(defOp);
  if (!broadcast)
    return {};
  // A broadcast of a vector is not a splat of a scalar.
  if (isa<VectorType>(broadcast.getSourceType()))
    return {};
  return broadcast.getSource();
}
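// Pattern: shuffling two splats of the same scalar is itself a splat of that
// scalar, so the shuffle folds to a broadcast of it to the result type.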
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getV1());
    if (!splat || getScalarSplatSource(op.getV2()) != splat)
      return failure();

    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};
3303class ShuffleInterleave :
public OpRewritePattern<ShuffleOp> {
3307 LogicalResult matchAndRewrite(ShuffleOp op,
3308 PatternRewriter &rewriter)
const override {
3309 VectorType resultType = op.getResultVectorType();
3310 if (resultType.isScalable())
3312 op,
"ShuffleOp can't represent a scalable interleave");
3314 if (resultType.getRank() != 1)
3316 op,
"ShuffleOp can't represent an n-D interleave");
3318 VectorType sourceType = op.getV1VectorType();
3319 if (sourceType != op.getV2VectorType() ||
3320 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3322 op,
"ShuffleOp types don't match an interleave");
3325 ArrayRef<int64_t> shuffleMask = op.getMask();
3326 int64_t resultVectorSize = resultType.getNumElements();
3327 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3328 int64_t maskValueA = shuffleMask[i * 2];
3329 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3330 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3332 "ShuffleOp mask not interleaving");
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);
}

void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest) {
  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, OpFoldResult position) {
  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<int64_t> position) {
  SmallVector<OpFoldResult> posVals;
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

LogicalResult InsertOp::verify() {
  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
    if (srcTy.getRank() == 0)
      return emitOpError(
          "expected a scalar instead of a 0-d vector as the source operand");

  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(*this, constIdx,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding dest vector dimension";
      }
    }
  }
  return success();
}
/// Calculate the linearized position for inserting at `positions` into a
/// vector of type `destTy`, padding unspecified trailing dimensions with 0.
static int64_t calculateInsertPosition(VectorType destTy,
                                       ArrayRef<int64_t> positions) {
  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
  assert(positions.size() <= completePositions.size() &&
         "positions size must be less than or equal to destTy rank");
  copy(positions, completePositions.begin());
  return linearize(completePositions, computeStrides(destTy.getShape()));
}

/// Pattern to rewrite an InsertOp as a BroadcastOp when source and destination
/// have the same number of elements.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType =
        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
    return success();
  }
};

/// Pattern to rewrite insert(splat-like(v), splat-like(v)) as broadcast(v).
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getValueToStore());
    if (!splat || getScalarSplatSource(op.getDest()) != splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};

/// Pattern to rewrite a chain of insertions that fully initializes the
/// destination vector as a single vector.from_elements op.
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    VectorType destTy = op.getDestVectorType();
    if (destTy.isScalable())
      return failure();
    // Only match the last insert of a chain: no user may insert into this
    // op's result again.
    for (Operation *user : op.getResult().getUsers())
      if (auto insertOp = dyn_cast<InsertOp>(user))
        if (insertOp.getDest() == op.getResult())
          return failure();

    InsertOp currentOp = op;
    SmallVector<InsertOp> chainInsertOps;
    while (currentOp) {
      // Dynamic positions cannot be analyzed statically.
      if (currentOp.hasDynamicPosition())
        return failure();

      chainInsertOps.push_back(currentOp);
      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
      // Stop following the chain once a dest has more than one use.
      if (currentOp && !currentOp->hasOneUse())
        break;
    }

    int64_t vectorSize = destTy.getNumElements();
    int64_t initializedCount = 0;
    SmallVector<bool> initializedDestIdxs(vectorSize, false);
    SmallVector<int64_t> pendingInsertPos;
    SmallVector<int64_t> pendingInsertSize;
    SmallVector<Value> pendingInsertValues;

    for (auto insertOp : chainInsertOps) {
      // Insertions at poison positions contribute nothing.
      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
        continue;

      int64_t insertBeginPosition =
          calculateInsertPosition(destTy, insertOp.getStaticPosition());
      int64_t insertSize = 1;
      if (auto srcVectorType =
              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
        insertSize = srcVectorType.getNumElements();

      assert(insertBeginPosition + insertSize <= vectorSize &&
             "insert would overflow the vector");

      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
                                           insertBeginPosition + insertSize)) {
        if (initializedDestIdxs[index])
          continue;
        initializedDestIdxs[index] = true;
        ++initializedCount;
      }

      // Defer op creation until the chain is known to initialize the whole
      // destination.
      pendingInsertPos.push_back(insertBeginPosition);
      pendingInsertSize.push_back(insertSize);
      pendingInsertValues.push_back(insertOp.getValueToStore());

      if (initializedCount == vectorSize)
        break;
    }

    // The chain must fully initialize the destination vector.
    if (initializedCount != vectorSize)
      return failure();

    SmallVector<Value> elements(vectorSize);
    // Walk the chain from the earliest insertion to the latest so that later
    // insertions overwrite earlier ones.
    for (auto [insertBeginPosition, insertSize, valueToStore] :
         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
                                 pendingInsertValues))) {
      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
      // Scalar insertion.
      if (!srcVectorType) {
        elements[insertBeginPosition] = valueToStore;
        continue;
      }
      // Vector insertion: decompose the stored vector into its elements.
      SmallVector<Type> elementToInsertTypes(insertSize,
                                             srcVectorType.getElementType());
      auto elementsToInsert = vector::ToElementsOp::create(
          rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
        elements[insertBeginPosition + linearIdx] =
            elementsToInsert.getResult(linearIdx);
      }
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
    return success();
  }
};
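// Sketch of the rewrite performed above (illustrative, not from the original
// source): a chain of scalar inserts that touches every lane, e.g.
//
//   %0 = vector.insert %a, %v [0] : f32 into vector<2xf32>
//   %1 = vector.insert %b, %0 [1] : f32 into vector<2xf32>
//
// becomes
//
//   %1 = vector.from_elements %a, %b : vector<2xf32>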
/// Fold an insert of a constant source into a constant destination by
/// overwriting the destination elements in place.
static Attribute
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
                                  Attribute dstAttr,
                                  int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst)
    return {};

  if (!srcAttr)
    return {};

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  int64_t insertBeginPosition =
      calculateInsertPosition(destTy, insertOp.getStaticPosition());
  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  // Convert the inserted values to the element type of the destination when
  // there is a mismatch.
  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(convertIntegerAttr(value, destEltType));
  } else {
    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
  }

  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);
  return DenseElementsAttr::get(destTy, allValues);
}

/// Folder to replace the `dest` operand of an insert op with the root dest of
/// the insert use chain when both ops insert at the same position.
static Value foldInsertUseChain(InsertOp insertOp) {
  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
  if (!destInsert)
    return {};

  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
    return {};

  insertOp.setOperand(1, destInsert.getDest());
  return insertOp.getResult();
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertChainFullyInitialized>(context);
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold "vector.insert %v, %dest []" (i.e. no position) when the types of
  // source and destination match.
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  // Fold `arith.constant` indices into the `vector.insert` operation.
  SmallVector<Value> operands = {getValueToStore(), getDest()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
  if (auto res = foldInsertUseChain(*this))
    return res;
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }
  return inplaceFolded;
}

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}

// TODO: Should be moved to Tablegen ConfinedAttr attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}

// Returns true if all integers in `arrayAttr` are in the interval [min, max).
// If `halfOpen` is false, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}

// Returns true if, for each dimension i, arrayAttr[i] is in the interval
// [min, shape[i]). If `halfOpen` is false, the admissible interval is
// [min, shape[i]].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}

// Returns true if, for each dimension i, the sum arrayAttr1[i] + arrayAttr2[i]
// is in the interval [min, shape[i]). If `halfOpen` is false, the admissible
// interval is [min, shape[i]].
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}

LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
/// Pattern to fold an insert_strided_slice of a splat into an identical splat
/// destination: the insert is a no-op.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto dst = insertStridedSliceOp.getDest();
    auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
    if (!splat || getScalarSplatSource(dst) != splat)
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, dst);
    return success();
  }
};

/// Pattern to fold an insert_strided_slice that re-inserts a slice extracted
/// from the same destination at the same position: the result is the
/// destination itself.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getValueToStore()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check that the insert and extract use the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() !=
            insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};

/// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
/// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if the destination is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    TypedValue<VectorType> sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Support poison.
    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Calculate the destination element indices by enumerating all slice
    // positions within the destination and linearizing them. The enumeration
    // order is lexicographic, which yields a sequence of monotonically
    // increasing linearized position indices.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};

void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getValueToStore();
  return {};
}

void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}

void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}

ParseResult OuterProductOp::parse(OpAsmParser &parser,
                                  OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand.
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}

LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what's currently supported in terms of
      // scalable vectors. However, we could relax this if there's a use case.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}

/// A masked OuterProductOp expects an i1 mask whose shape (and scalable dims)
/// match the result vector's.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}
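// Example (illustrative, not from the original source): for
//
//   %r = vector.outerproduct %lhs, %rhs, %acc
//       : vector<4xf32>, vector<[8]xf32>
//
// the result type is vector<4x[8]xf32>, and the expected mask type computed
// above is vector<4x[8]xi1>: the mask covers the 2-D iteration space and
// inherits the scalable dims of the result.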
/// Compute the result type of an extract_strided_slice: leading dimensions
/// take the requested sizes; trailing (unsliced) dimensions keep the source
/// sizes.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}

void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}

LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape,
                                                offName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape,
                                                sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }

  return success();
}
// When the source of an extract_strided_slice is a chain of
// insert_strided_slice ops, forward the inserted vector when the extracted
// chunk is a subset of one inserted chunk.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract an integer out of an ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the rank of the extracted chunk exceeds what the insert describes we
    // cannot fold.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the extract offset starts inside the inserted interval.
      if (start <= offset && offset < end) {
        // The extract interval overlaps but is not fully included: partial
        // overlap.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk is a subset of the inserted chunk: forward it.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getValueToStore());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // Disjoint chunks: keep looking through the insert chain. A partial
    // overlap means we cannot fold.
    if (disjoint)
      insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
    else
      return failure();
  }
  return failure();
}

// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
static OpFoldResult
foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
                                        Attribute foldInput) {
  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
  if (!dense)
    return {};

  // TODO: Handle non-unit strides when they become available.
  if (op.hasNonUnitStrides())
    return {};

  VectorType sourceVecTy = op.getSourceVectorType();
  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

  VectorType sliceVecTy = op.getType();
  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
  int64_t rank = sliceVecTy.getRank();

  // Expand offsets to the rank of the source vector; unspecified trailing
  // offsets are zero.
  SmallVector<int64_t, 4> offsets(rank, 0);
  copy(getI64SubArray(op.getOffsets()), offsets.begin());

  // Calculate the slice elements by enumerating all slice positions and
  // linearizing them. The enumeration order is lexicographic, which yields a
  // sequence of monotonically increasing linearized position indices.
  const auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(sliceVecTy.getNumElements());
  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
  do {
    int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
  } while (succeeded(
      incSlicePosition(currSlicePosition, sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
             sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  return DenseElementsAttr::get(sliceVecTy, sliceValues);
}

OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getSource();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();

  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());

  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
}

void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getOffsets(), results);
}

// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to CreateMaskOp.
class StridedSliceCreateMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    Location loc = extractStridedSliceOp.getLoc();
    // Return if the source is not defined by a CreateMaskOp.
    auto createMaskOp =
        extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
    if (!createMaskOp)
      return failure();
    // Return if the extract has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather mask dimension sizes.
    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the mask region of the slice: each sliced dimension keeps
    // (mask size - slice offset) active lanes.
    SmallVector<Value> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      IntegerAttr offsetAttr =
          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
      Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
      Value sliceMaskDimSize =
          arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Unsliced dimensions keep their original mask sizes.
    llvm::append_range(
        sliceMaskDimSizes,
        llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
    // Replace the extract with a CreateMaskOp over the sliced mask region.
    rewriter.replaceOpWithNewOp<CreateMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};

// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the source is not defined by a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if the extract has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes.
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the slice of the mask region.
    SmallVector<int64_t> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any sliced mask dimension is zero, the whole sliced mask is
    // all-false.
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
    // Replace the extract with a ConstantMaskOp over the sliced mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
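// Worked example of the computation above (illustrative, not from the
// original source): slicing a vector.constant_mask [3] : vector<8xi1> with
// sliceOffset = 2 and sliceSize = 4 yields
//
//   sliceMaskDimSize = max(0, min(2 + 4, 3) - 2) = 1
//
// i.e. a vector.constant_mask [1] : vector<4xi1> -- only the first of the
// four extracted lanes was set in the original mask.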
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the source of the broadcast are the
    // same as the destination of the extract. If so, a plain broadcast
    // suffices, as the original dimensions are untouched.
    bool needsSlice = false;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != 1 &&
          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        needsSlice = true;
        break;
      }
    }
    Value source = broadcast.getSource();
    if (needsSlice) {
      SmallVector<int64_t> offsets =
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
      SmallVector<int64_t> sizes =
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
      for (unsigned i = 0; i < srcRank; i++) {
        if (srcVecType.getDimSize(i) == 1) {
          // This dimension was broadcast from a unit dimension, so the slice
          // of the source must cover exactly that unit dimension.
          offsets[i] = 0;
          sizes[i] = 1;
        }
      }
      source = ExtractStridedSliceOp::create(
          rewriter, op->getLoc(), source, offsets, sizes,
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};

/// Pattern to rewrite an extract_strided_slice of a splat-like value as a
/// broadcast of the splatted scalar.
class StridedSliceSplat final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getSource());
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build: trailing
    // dimensions that are extracted in full need no offset.
    SmallVector<int64_t> sizes;
    populateFromInt64AttrArray(op.getSizes(), sizes);
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the innermost dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The sliced dimensions must all have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions: extend the
    // extract past unit-sized sliced dimensions.
    SmallVector<int64_t> offsets;
    populateFromInt64AttrArray(op.getOffsets(), offsets);
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
                                              extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
                                                     extract);
    return success();
  }
};

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
              StridedSliceBroadcast, StridedSliceSplat,
              ContiguousExtractStridedSliceToExtract>(context);
}

/// Builder that sets padding to `poison` if not provided and an empty mask.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr,
                           std::optional<Value> padding) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}

/// Builder with an explicit permutation map; in_bounds defaults to all-false.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds,
                           std::optional<Value> padding) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(),
                                                  false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, *padding,
        permutationMapAttr, inBoundsAttr);
}

/// Builder that infers a minor-identity permutation map.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           std::optional<ArrayRef<bool>> inBounds,
                           std::optional<Value> padding) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(),
                                                  false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the "
            "zero constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each "
                         "result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}

static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError(
          "does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of "
                           "the same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (inBounds.size() != static_cast<size_t>(permutationMap.getNumResults()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}

static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds when it is the all-false default.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}

VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

  // 0-D masks are not supported; a vector<1xi1> is used instead.
  if (maskShape.empty())
    maskShape.push_back(1);

  SmallVector<bool> scalableDims =
      applyPermutationMap(invPermMap, vecType.getScalableDims());

  return VectorType::get(maskShape, i1Type, scalableDims);
}
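// Example (illustrative, not from the original source): for a transfer with
// permutation map (d0, d1, d2) -> (d2, d0) and vector type vector<4x8xf32>,
// compressing the unused dim d1 gives (d0, d1) -> (d1, d0), whose inverse
// maps the vector shape [4, 8] to the mask shape [8, 4]. The expected mask
// type is therefore vector<8x4xi1>: the mask is indexed in source (memory)
// order, not in vector order.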
ParseResult TransferReadOp::parse(OpAsmParser &parser,
                                  OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName =
      TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(),
                                              false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location,
          "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults())
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}

LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap,
                              getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type: padding must have the same type.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
/// Returns the mask type expected by this operation; used mostly for
/// verification.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

VectorType TransferReadOp::getVectorType() {
  return cast<VectorType>(getVector().getType());
}

template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  //   op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(),
  //   indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}

template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Indices of non-broadcast dims, used when analyzing broadcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds; check whether we can statically determine
    // it is in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }

    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are in-bounds, then
  // all broadcast dims are in-bounds as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}

template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();

  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();

  op.getMaskMutable().clear();
  return success();
}

/// Fold a transfer_read that reads back the value a transfer_write just wrote
/// to the same tensor at the same indices (read-after-write).
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldTransferFullMask(*this)))
    return getResult();
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}

/// Fold a read directly after a write into a broadcast (plus transpose) of
/// the written vector, provided both transfers access the same tensor chunk.
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if we would need an alias analysis.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail if we would need a bounds analysis.
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // TODO: Support cases where a dim is explicitly written but implicitly
    // read (i.e., a unit dim that is rank-reduced).
    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
      return failure();
    // This pattern should only catch the broadcast case; the non-broadcast
    // case is handled separately to keep application conditions clean.
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    // If the indices are not the same, a shift may be required; bail.
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();

    Value vec = defWrite.getVector();
    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
    AffineMap map = readMap.compose(writeMap);
    if (map.getNumResults() == 0)
      return failure();
    // Calculate the permutation to apply to go from the vector stored to the
    // vector read.
    SmallVector<unsigned> permutation;
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
      return failure();

    Location loc = readOp.getLoc();
    // Calculate the broadcast shape by applying the reverse permutation to
    // the final shape we want.
    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
    SmallVector<int64_t> broadcastShape(destShape.size());
    SmallVector<bool> broadcastScalableFlags(destShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType(),
        broadcastScalableFlags);
    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
                                                     transposePerm);
    return success();
  }
};

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  // ...
}
/// Builder with mask and in_bounds attributes; the result type is present
/// only for tensor destinations.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr, Value mask,
                            ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// Builder without a mask.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// Builder with an explicit permutation map; in_bounds defaults to all-false.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// Builder that infers a minor-identity permutation map.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}

ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(),
                                              false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location,
          "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults())
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}

void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}

LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap,
                              getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

/// Returns the mask type expected by this operation; used mostly for
/// verification.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

Value TransferWriteOp::getVector() { return getOperand(0); }
VectorType TransferWriteOp::getVectorType() {
  return cast<VectorType>(getValueToStore().getType());
}
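// Example that satisfies the checks above (illustrative, not from the
// original source):
//
//   %t1 = vector.transfer_write %v, %t0[%i, %j]
//       {in_bounds = [true, true]}
//       : vector<4x8xf32>, tensor<16x16xf32>
//
// Unlike transfer_read there is no padding operand, and broadcast dimensions
// in the permutation map are rejected by the verifier above.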
/// Fold the pattern "read a whole tensor into a vector, then write that
/// vector back whole into an identically-shaped tensor" to the tensor that
/// was read.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero, bail.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getBase());
  return success();
}

static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getBase() == write.getBase() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}

/// Fold a write of a vector that was just read from the same tensor at the
/// same indices (write-after-read): the write is a no-op.
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getBase());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  if (succeeded(foldTransferFullMask(*this)))
    return success();
  return memref::foldMemRefCast(*this);
}

std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}

/// Remove dead transfer writes from the SSA chain so that they can be
/// eliminated by DCE: a write is dead when a later write completely
/// overwrites the same tensor chunk.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely
      // look at the store before it to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};

/// Swap a tensor.extract_slice that sits between a vector.transfer_write and
/// a tensor.insert_slice in front of the transfer_write, when the write
/// covers the full tensor.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
    // rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if the vector::TransferWriteOp has non-zero offset.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TranferWriteOp has non-zero offset");
    }

    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp ranks differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(),
                         extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the vector::TransferWriteOp may not write the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the tensor::ExtractSliceOp in front of the
    // vector::TransferWriteOp. Set all in_bounds to false and let the folder
    // infer them.
    SmallVector<bool> newInBounds(vectorShape.size(), false);
    auto newExtractOp = tensor::ExtractSliceOp::create(
        rewriter, extractOp.getLoc(), insertOp.getSourceType(),
        insertOp.getDest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    auto newTransferWriteOp = TransferWriteOp::create(
        rewriter, transferOp.getLoc(), transferOp.getVector(),
        newExtractOp.getResult(), transferOp.getIndices(),
        transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(newInBounds));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  // ...
}
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If rank == 0 or the vector has a single element, the minor identity map
  // is well defined for any layout.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}

LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < resVecTy.getRank())
    return emitOpError(
        "destination memref has lower rank than the result vector");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
OpFoldResult LoadOp::fold(FoldAdaptor) {
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}

LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < valueVecTy.getRank())
    return emitOpError(
        "source memref has lower rank than the vector to store");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}

LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
/// Pattern to fold a masked load with an all-true or all-false constant mask.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}

OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}

LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}

/// Pattern to fold a masked store with an all-true or all-false constant
/// mask.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}

LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// A masked GatherOp expects an i1 mask whose shape matches the index
/// vector's.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// Check whether `indexVec` is a constant 1-D vector of consecutive values
/// [0, 1, 2, ...], or a vector.step op.
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  if (indexVec.getDefiningOp<StepOp>())
    return success();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
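// Example (illustrative, not from the original source): the helper above
// succeeds for
//
//   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
//
// or for %idx = vector.step : vector<4xindex>, which lets the contiguous
// gather/scatter folders below rewrite the op into a maskedload/maskedstore
// over consecutive addresses.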
6011class GatherFolder final :
public OpRewritePattern<GatherOp> {
6014 LogicalResult matchAndRewrite(GatherOp gather,
6015 PatternRewriter &rewriter)
const override {
6020 rewriter.
replaceOp(gather, gather.getPassThru());
6025 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
6031class FoldContiguousGather final :
public OpRewritePattern<GatherOp> {
6034 LogicalResult matchAndRewrite(GatherOp op,
6035 PatternRewriter &rewriter)
const override {
6036 if (!isa<MemRefType>(op.getBase().getType()))
6039 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6043 op.getOffsets(), op.getMask(),
6050void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6051 MLIRContext *context) {
6052 results.
add<GatherFolder, FoldContiguousGather>(context);
6055FailureOr<std::optional<SmallVector<Value>>>
6056GatherOp::bubbleDownCasts(OpBuilder &builder) {
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (valueVType.getElementType() != baseType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
    return success();
  }
};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(), ValueRange());
}
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(), getResult());
}
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(), ValueRange());
}
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

LogicalResult ShapeCastOp::verify() {
  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that element type is preserved.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitOpError("has different source and result element types");

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current)
        return false;
      current = p;
    }
  }
  return true;
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x).
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x) if the transpose preserves the
  // element order.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // shape_cast(broadcast(x)) -> x if x already has the result type.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> reshaped constant.
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  // shape_cast(poison) -> poison.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());

  return {};
}
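/// Compute a new vector type from `oldType` by dropping trailing non-scalable
/// unit dims, e.g. vector<4x1x1xi1> becomes vector<4xi1>. At least one dim is
/// always kept, so vector<1x1xi1> becomes vector<1xi1> rather than a 0-D
/// vector.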
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
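/// Fold a shape cast that only drops trailing unit dims of a create_mask or
/// constant_mask into a new, lower-rank mask op. E.g. (illustrative IR):
///
///   %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
///   %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
///
/// becomes:
///
///   %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
namespace {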
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Every dropped mask dim size must be the constant 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Every dropped mask dim size must be 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either Y = ShapeCast(X)
/// or Y = Broadcast(X), preferring the shape cast when both are possible.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;
    (void)srcIsScalar;

    // Y = ShapeCast(X) is possible when X has the same number of elements as
    // the result, i.e. the broadcast only prepended unit dims.
    if (srcVectorType) {
      if (srcVectorType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    // Otherwise try Y = Broadcast(X), which also covers the scalar source.
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    if (vector::isBroadcastableTo(broadcastOp.getSourceType(), dstVectorType) ==
        BroadcastableToResult::Success) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, dstVectorType, broadcastOp.getSource());
      return success();
    }
    return failure();
  }
};
} // namespace
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results
      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
          context);
}
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        unsigned dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting a splat to a wider integer type.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower-width element.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(), intBits);
        }
      }
    }
  }

  return {};
}
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}
LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity transposes, and more generally any transpose that
  // preserves the order of the elements.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return {};
}
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}
std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
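namespace {

/// Fold transpose(transpose(x)) into a single transpose with the composed
/// permutation, where composed[i] = permutation1[permutation2[i]]. E.g.
/// composing the parent permutation [1, 0, 2] with [0, 2, 1] yields the
/// single permutation [1, 2, 0].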
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};
/// Replace transpose(splat-like(v)) with broadcast(v).
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(transposeOp.getVector());
    if (!splat)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transposeOp, transposeOp.getResultVectorType(), splat);
    return success();
  }
};
/// Folds transpose(create_mask) or transpose(constant_mask) into a new mask
/// op with permuted mask dimensions.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};
/// Folds transpose(shape_cast) into a new shape_cast when the transpose is
/// order preserving.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};
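/// Rewrite transpose(from_elements(...)) as a from_elements with the element
/// list reordered according to the permutation. E.g. (illustrative IR):
///
///   %v = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
///   %t = vector.transpose %v, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
///
/// becomes:
///
///   %t = vector.from_elements %a, %c, %b, %d : vector<2x2xf32>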
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElementsOp =
        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
    if (!fromElementsOp)
      return failure();

    VectorType srcTy = fromElementsOp.getDest().getType();
    VectorType dstTy = transposeOp.getType();

    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    int64_t rank = srcTy.getRank();

    // Invert the permutation to map destination indices back to source
    // indices: result dim i takes its extent from source dim permutation[i].
    SmallVector<int64_t> inversePerm(rank, 0);
    for (int64_t i = 0; i < rank; ++i)
      inversePerm[permutation[i]] = i;

    ArrayRef<int64_t> srcShape = srcTy.getShape();
    ArrayRef<int64_t> dstShape = dstTy.getShape();
    SmallVector<int64_t> srcIdx(rank, 0);
    SmallVector<int64_t> dstIdx(rank, 0);
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);

    auto elementsOld = fromElementsOp.getElements();
    SmallVector<Value> elementsNew;
    int64_t dstNumElements = dstTy.getNumElements();
    elementsNew.reserve(dstNumElements);

    // Walk the destination in row-major order; for each destination position,
    // compute the linearized source position and pick up that element.
    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
      dstIdx = delinearize(linearIdx, dstStrides);
      for (int64_t j = 0; j < rank; ++j)
        srcIdx[j] = dstIdx[inversePerm[j]];
      int64_t srcLin = linearize(srcIdx, srcStrides);
      elementsNew.push_back(elementsOld[srcLin]);
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(transposeOp, dstTy,
                                                        elementsNew);
    return success();
  }
};
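/// Fold transpose(broadcast(x)) into broadcast(x) when the permutation only
/// moves broadcast or unit dimensions, so the result is still directly
/// broadcastable from x. E.g. (illustrative IR):
///
///   %0 = vector.broadcast %in : vector<1x1xi32> to vector<1x8xi32>
///   %1 = vector.transpose %0, [1, 0] : vector<1x8xi32> to vector<8x1xi32>
///
/// can be rewritten as the equivalent:
///
///   %1 = vector.broadcast %in : vector<1x1xi32> to vector<8x1xi32>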
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    // Check that the permutation only moves indices within groups of
    // broadcast/unit dims; otherwise the result is not broadcastable from the
    // original source.
    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    // The preceding check guarantees broadcastability at this point.
    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};
} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeFromElements,
              FoldTransposeBroadcast>(context);
}
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}
LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");

  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }

  // Verify that if one mask dim size is zero, they all are, because the mask
  // region is the conjunction of the per-dimension intervals.
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");

  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> bounds = getMaskDimSizes();
  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();

  auto createBoolSplat = [&](bool x) {
    return DenseElementsAttr::get(getVectorType(), x);
  };

  // Check the corner case of 0-D vectors first.
  if (vectorSizes.empty()) {
    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
    return createBoolSplat(bounds[0] == 1);
  }
  // Fold when the mask is all-true or all-false.
  if (bounds == vectorSizes)
    return createBoolSplat(true);
  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
    return createBoolSplat(false);
  return OpFoldResult();
}
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: 0-D vectors are treated as vectors with one dimension.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp, and
    // collect the constant mask dimension sizes.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // A constant value on a scalable dim can only be all-false (0).
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // A constant multiple of vector.vscale can only be all-true.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to the constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dims is zero the mask is all-false; constant_mask requires
    // all mask dim sizes to be zero in that case.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace 'createMaskOp' with a ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
} // namespace
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse the mask operand and the optional passthru operand.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();

  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse the op region and normalize its terminator.
  if (parser.parseRegion(maskRegion))
    return failure();

  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");

    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }

  return success();
}
void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.empty() || region.front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.front();
  if (isa<vector::YieldOp>(block.back()))
    return;

  // 3. For a non-empty `vector.mask` without an explicit terminator, create a
  // default terminator if the number of masked operations is not exactly one
  // (this case will then fail verification).
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // Otherwise, create a terminator that yields the masked operation's results.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Elides empty vector.mask operations with or without return values.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();

  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  if (terminator.getNumOperands() == 0) {
    // Nothing to return.
    return success();
  }

  // Return the terminator operands.
  llvm::append_range(results, terminator.getOperands());
  return success();
}

LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();

  // An all-true mask lets the masked operation run unmasked.
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move the maskable operation outside of the `vector.mask` region.
  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
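namespace {

/// Canonicalize empty `vector.mask` ops that cannot be handled in
/// `MaskOp::fold` because they require creating new ops: an empty mask with a
/// passthru operand reduces to a lane-wise select. E.g. (illustrative IR):
///
///   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> }
///     : vector<8xi1> -> vector<8xf32>
///
/// becomes:
///
///   %0 = arith.select %mask, %a, %passthru : vector<8xi1>, vector<8xf32>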
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    if (!maskOp.isEmpty())
      return failure();

    if (!maskOp.hasPassthru())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
        terminator.getOperand(0), maskOp.getPassthru());

    return success();
  }
};
} // namespace

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
/// Returns the operation masked by this 'vector.mask', or null if the mask
/// region is empty (terminator only).
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRanges) {
  auto resultType = cast<VectorType>(getType());
  // The result range of a scalable step is not known at compile time.
  if (resultType.isScalable())
    return;
  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
  APInt zero(bitwidth, 0);
  APInt high(bitwidth, resultType.getDimSize(0) - 1);
  ConstantIntRanges result = {zero, high, zero, high};
  setResultRanges(getResult(), result);
}
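namespace {

/// Fold `vector.step -> arith.cmpi` when the comparison constant is large
/// enough that the result is the same for every lane. E.g. (illustrative IR):
///
///   %cst = arith.constant dense<7> : vector<3xindex>
///   %stp = vector.step : vector<3xindex>
///   %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
///
/// becomes:
///
///   %out = arith.constant dense<false> : vector<3xi1>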
struct StepCompareFolder : public OpRewritePattern<StepOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    const int64_t stepSize = stepOp.getResult().getType().getNumElements();

    for (OpOperand &use : stepOp.getResult().getUses()) {
      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
      if (!cmpiOp)
        continue;

      // arith.cmpi canonicalization puts constants on the rhs, so only handle
      // the case where the step value is the lhs.
      const unsigned stepOperandNumber = use.getOperandNumber();
      if (stepOperandNumber != 0)
        continue;

      // The other operand must be a constant (splat).
      unsigned constOperandNumber = 1;
      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
      std::optional<int64_t> maybeConstValue =
          getConstantIntValue(otherOperand);
      if (!maybeConstValue.has_value())
        continue;

      int64_t constValue = maybeConstValue.value();
      arith::CmpIPredicate pred = cmpiOp.getPredicate();

      // The step values are [0, ..., stepSize - 1]; determine whether the
      // comparison has the same result for all of them.
      auto maybeSplat = [&]() -> std::optional<bool> {
        if ((pred == arith::CmpIPredicate::ult ||
             pred == arith::CmpIPredicate::uge) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ult;

        if ((pred == arith::CmpIPredicate::ule ||
             pred == arith::CmpIPredicate::ugt) &&
            stepSize - 1 <= constValue) {
          return pred == arith::CmpIPredicate::ule;
        }

        if ((pred == arith::CmpIPredicate::eq ||
             pred == arith::CmpIPredicate::ne) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ne;

        return std::nullopt;
      }();

      if (!maybeSplat.has_value())
        continue;

      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
      if (!type)
        continue;

      Value splat = mlir::arith::ConstantOp::create(
          rewriter, cmpiOp.getLoc(),
          DenseElementsAttr::get(type, maybeSplat.value()));

      rewriter.replaceOp(cmpiOp, splat);
      return success();
    }

    return failure();
  }
};
} // namespace

void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<StepCompareFolder>(context);
}
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as the masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Move the op into the mask region block and yield its results.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}

/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask is valid, otherwise the maskable
/// operation itself.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return MaskOp::create(builder, maskableOp->getLoc(),
                          maskableOp->getResultTypes(), mask, passthru,
                          maskableOp, createMaskOpRegion);
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, maskableOp,
                        createMaskOpRegion);
}

/// Creates a vector select that picks `newValue` or `passthru` per lane based
/// on `mask`, used to propagate the pass-through value for masked-off lanes.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return arith::SelectOp::create(builder, newValue.getLoc(),
                                 newValue.getType(), mask, newValue, passthru);
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta between the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble down, in place, a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
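A sketch of makeArithReduction folding a value into an accumulator; newVal and acc are hypothetical vector values of matching type:
  // Emits arith.addi or arith.addf depending on the element type.
  Value updated = vector::makeArithReduction(
      builder, loc, vector::CombiningKind::ADD, newVal, acc);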
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for subscripts in the vector dialect.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
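A hedged sketch of maskOperation; reductionOp and mask are assumed values from surrounding code:
  // `reductionOp` is assumed to implement MaskableOpInterface; `mask` must
  // have the mask type expected by that op.
  Operation *masked = vector::maskOperation(builder, reductionOp, mask);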
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g. %cst * vector.vscale), return the multiplier.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
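A short sketch of the map this produces, under assumed types built elsewhere:
  // For shapedType = memref<4x8x16xf32> and vectorType = vector<16xf32>,
  // the minor identity map is (d0, d1, d2) -> (d2).
  AffineMap map = vector::getTransferMinorIdentityMap(shapedType, vectorType);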
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector lane based on mask.
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the accessed tensor/memref to be the same.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operations to access the same tensor/memref.
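A hedged sketch of how the disjointness helper might gate a forwarding decision; defWrite and read are assumed transfer_write/transfer_read ops on the same memref:
  if (vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(defWrite.getOperation()),
          cast<VectorTransferOpInterface>(read.getOperation()),
          /*testDynamicValueUsingBounds=*/true)) {
    // Proven disjoint: the read cannot observe data produced by the write.
  }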
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-terminated region of a vector.mask op, with maskableOp as the masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated with op.
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector.broadcast op.
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
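A small sketch of a common use; ofr is an assumed OpFoldResult, e.g. one entry of a mixed offsets list:
  std::optional<int64_t> cst = getConstantIntValue(ofr);
  if (cst && *cst == 0) {
    // The offset is statically zero, whether it came from an IntegerAttr or
    // an arith.constant-defined SSA value.
  }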
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
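A worked sketch of the strides helpers on a static 4x8x16 shape (computeStrides and linearize are listed elsewhere in this section):
  // Suffix products of the sizes: {8 * 16, 16, 1} == {128, 16, 1}.
  SmallVector<int64_t> strides = computeStrides({4, 8, 16});
  // 1 * 128 + 2 * 16 + 3 == 163.
  int64_t linear = linearize({1, 2, 3}, strides);
  // Recovers the offsets {1, 2, 3}.
  SmallVector<int64_t> offsets = delinearize(linear, strides);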
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
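A sketch of the permutation semantics; pickDims is a hypothetical helper and ctx an assumed MLIRContext*:
  SmallVector<int64_t> pickDims(MLIRContext *ctx) {
    // (d0, d1, d2) -> (d2, d0) applied to {10, 11, 12} yields {12, 10}.
    AffineMap perm = AffineMap::get(
        3, 0, {getAffineDimExpr(2, ctx), getAffineDimExpr(0, ctx)}, ctx);
    return applyPermutationMap(perm, ArrayRef<int64_t>{10, 11, 12});
  }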
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
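A sketch of decomposeMixedValues; mixedSizes is an assumed ArrayRef<OpFoldResult>:
  // Static entries land in the first array; dynamic entries are encoded as
  // ShapedType::kDynamic there, and their SSA values are collected in the
  // second array.
  auto [staticSizes, dynamicSizes] = decomposeMixedValues(mixedSizes);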
Return a fused vector::ContractionOp which represents a pattern such as the addition of a value to a zero-accumulator vector.contract, fused into a single vector.contract that takes that value as its accumulator.
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
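A hedged end-to-end sketch tying the pattern pieces above together; FoldAddZero is a hypothetical pattern that drops an arith.addi whose right-hand side is the constant zero:
  struct FoldAddZero : OpRewritePattern<arith::AddIOp> {
    using Base::Base; // Inherit the (MLIRContext *, benefit, ...) constructor.
    LogicalResult matchAndRewrite(arith::AddIOp addOp,
                                  PatternRewriter &rewriter) const override {
      if (!matchPattern(addOp.getRhs(), m_Zero()))
        return rewriter.notifyMatchFailure(addOp, "rhs is not zero");
      rewriter.replaceOp(addOp, addOp.getLhs());
      return success();
    }
  };
  // Registered with: patterns.add<FoldAddZero>(context);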
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)