#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
  for (bool b : denseElts.getValues<bool>())

  else if (!b && val <= 0)

auto shape = m.getType().getShape();
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
  if (maskIdx < dimSize)

auto maskOperands = m.getOperands();
for (Value operand : maskOperands) {
  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
    llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

vector::YieldOp::create(builder, loc);
switch (combiningKind) {
case CombiningKind::ADD:
case CombiningKind::MUL:

case CombiningKind::MINUI:
case CombiningKind::MINSI:
case CombiningKind::MAXUI:
case CombiningKind::MAXSI:
case CombiningKind::AND:
case CombiningKind::OR:
case CombiningKind::XOR:

case CombiningKind::MINNUMF:
case CombiningKind::MAXNUMF:
case CombiningKind::MINIMUMF:
case CombiningKind::MAXIMUMF:
  return llvm::isa<FloatType>(elementType);
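// Note (illustrative): the integer kinds above (MINUI/MINSI/.../XOR) are
// expected to be valid only for signless integer or index element types,
// while the *F kinds require a float element type, as the surviving
// `llvm::isa<FloatType>(elementType)` check shows.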
                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;

                                VectorType vectorType) {
  if (shapedType.getRank() == 0 &&

      shapedType.getRank(),
      shapedType.getContext());
                              vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();

  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)

                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||

                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
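// Note (illustrative): these helpers back store-to-load forwarding: a
// transfer_read may reuse the vector of an earlier in-bounds transfer_write
// only when indices, vector type, permutation map, and masks all match, so
// the read provably observes exactly the written values.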
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  if (transferA.getVectorType() != transferB.getVectorType())

  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];

    if (i < rankOffset) {
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)

      if (testDynamicValueUsingBounds) {
        FailureOr<uint64_t> delta =
        if (succeeded(delta) && *delta != 0)

        FailureOr<bool> testEqual =
        if (succeeded(testEqual) && !testEqual.value())

    int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
    if (cstIndexA.has_value() && cstIndexB.has_value()) {
      int64_t distance = std::abs(*cstIndexA - *cstIndexB);
      if (distance >= vectorDim)

    if (testDynamicValueUsingBounds) {
      FailureOr<int64_t> delta =
      if (succeeded(delta) && std::abs(*delta) >= vectorDim)

      FailureOr<int64_t> computeDelta =
      if (succeeded(computeDelta)) {
        if (std::abs(computeDelta.value()) >= vectorDim)

                                   VectorTransferOpInterface transferB,
                                   bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())

                                  testDynamicValueUsingBounds);
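// Note (illustrative): two transfers on the same base are provably disjoint
// when a leading (non-vector) index is known to differ, or when along a
// vector dimension the index distance is at least the vector size; e.g. 1-D
// reads of vector<4xf32> at offsets 0 and 4 cannot overlap.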
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    if (posInDim < dimSize + offsetInDim)

    posInDim = offsetInDim;

  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();

      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
  llvm::transform(foldResults, std::back_inserter(values),
    if (auto attr = dyn_cast<Attribute>(foldResult))
          builder, loc, cast<IntegerAttr>(attr).getInt())
    return cast<Value>(foldResult);

  if (lhs.getDefiningOp<vector::VectorScaleOp>())

  if (rhs.getDefiningOp<vector::VectorScaleOp>())

  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto intType = dyn_cast<IntegerType>(expectedType)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());

  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    auto intType = dyn_cast<IntegerType>(expectedType);

    APFloat floatVal = floatAttr.getValue();
    APInt intVal = floatVal.bitcastToAPInt();
    return IntegerAttr::get(expectedType, intVal);
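// Note (illustrative): when a FloatAttr must be materialized as an integer
// of the same bit width, the conversion above is a bitcast
// (APFloat::bitcastToAPInt), not a value cast, so the bit pattern is
// preserved exactly.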
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

void VectorDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();

  if (isa<ub::PoisonAttrInterface>(value))

  return arith::ConstantOp::materialize(builder, value, type, loc);
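// Note (illustrative): constant materialization appears to special-case
// poison attributes (ub::PoisonAttrInterface) before falling back to
// arith::ConstantOp::materialize for ordinary typed constants.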
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        CombiningKind kind) {
  for (const auto &en : llvm::enumerate(reductionMask))
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
LogicalResult MultiDimReductionOp::verify() {
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);

  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
           << " is incompatible with source type " << getSourceVectorType();

Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
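// Note (illustrative): the expected mask mirrors the *source* vector shape
// with an i1 element type and the same scalable dims, e.g. reducing a
// vector<2x[4]xf32> expects a mask of type vector<2x[4]xi1>.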
struct ElideUnitDimsInMultiDimReduction

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)

    OpBuilder::InsertionGuard guard(rewriter);

    if (reductionOp.isMasked()) {
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();

      rootOp = reductionOp;

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();

    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());

      mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());

        cast, nullptr, mask);

void MultiDimReductionOp::getCanonicalizationPatterns(
  results.add<ElideUnitDimsInMultiDimReduction>(context);
                         arith::FastMathFlags fastMathFlags) {

                         arith::FastMathFlags fastMathFlags) {
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind,
        vector,

LogicalResult ReductionOp::verify() {
  int64_t rank = getSourceVectorType().getRank();
    return emitOpError("unsupported reduction rank: ") << rank;

  Type eltType = getDest().getType();
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())

Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::OR, vector);
  case arith::AtomicRMWKind::minnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINNUMF, vector);
  case arith::AtomicRMWKind::maxnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXNUMF, vector);
  case arith::AtomicRMWKind::xori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::XOR, vector);
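// Note (illustrative): this switch maps arith::AtomicRMWKind reduction
// kinds onto vector.reduction combining kinds, e.g. "addf"/"addi" both
// become CombiningKind::ADD, and the signed/unsigned integer min/max kinds
// map to MINSI/MINUI and MAXSI/MAXUI respectively.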
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());

    if (maskableOp.isMasked()) {
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();

      rootOp = reductionOp;

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)

    Location loc = reductionOp.getLoc();
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
                                    reductionOp.getFastmathAttr(), mask);

  results.add<ElideSingleElementReduction>(context);
                      getIndexingMapsAttrName(result.name),
                      getIteratorTypesAttrName(result.name),
        return IteratorTypeAttr::get(builder.getContext(), t);

                      ContractionOp::getDefaultKind());

                          ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
                      CombiningKindAttr::get(builder.getContext(), kind));
  DictionaryAttr dictAttr;

  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));

  result.attributes.set(getIteratorTypesAttrName(result.name),

  if (!result.attributes.get(getKindAttrName(result.name))) {
                        getKindAttrName(result.name),
                        CombiningKindAttr::get(result.getContext(),
                                               ContractionOp::getDefaultKind()));

  if (masksInfo.empty())

  if (masksInfo.size() != 2)
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  std::array<VectorType, 2> maskTypes = {
  auto attrNames = getTraitAttrNames();
  traitAttrsSet.insert_range(attrNames);

  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();

          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);

                          llvm::make_second_range(batchDimMap));

  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
    expectedResultDims.push_back(lhsType.getDimSize(i));

  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
    expectedResultDims.push_back(rhsType.getDimSize(i));

  if (expectedResultDims.empty()) {
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");

    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");

         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);

    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");

    auto expectedShape = llvm::to_vector<4>(
          return cast<AffineConstantExpr>(e).getValue();
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");

  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
             << index << " to have no symbols";

    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;

    if (map.getNumDims() != numIterators)
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
             << index << " to be a projected permutation of its inputs";

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

    return emitOpError("invalid batch dimension map");

                               contractingDimMap, batchDimMap)))

  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
    return emitOpError("unsupported contraction type");

  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();

  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
        lhsType.getScalableDims()[dimIdx];

  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
        rhsType.getScalableDims()[dimIdx];

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), 1),
                         maskShapeScalableDims);

                        getIteratorTypesAttrName(), getKindAttrName()};
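// Note (illustrative): the contraction mask is shaped by the iteration
// space (one dim per iterator, numbered as in the indexing maps) rather
// than by any single operand, so a matmul with iterators (m, n, k) expects
// a mask of type vector<MxNxKxi1>.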
static std::vector<std::pair<int64_t, int64_t>>
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)

    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
void ContractionOp::getIterationBounds(
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);

    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
void ContractionOp::getIterationIndexMap(
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  getIterationBounds(shape);
template <typename AddOpType>
  auto canonicalize = [&](Value maybeContraction,
                          Value otherOperand) -> vector::ContractionOp {
    vector::ContractionOp contractionOp =
        dyn_cast_or_null<vector::ContractionOp>(
      return vector::ContractionOp();

    if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
            contractionOp.getAcc().getDefiningOp())) {
      if (maybeZero.getValue() ==
          rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
        bvm.map(contractionOp.getAcc(), otherOperand);
        auto newContraction =
            cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
        rewriter.replaceOp(addOp, newContraction.getResult());
        return newContraction;

    return vector::ContractionOp();

  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
  vector::ContractionOp contract = canonicalize(a, b);
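// Note (illustrative): this pattern folds add(contract(a, b, zero), c) into
// contract(a, b, c) by remapping the zero accumulator to the other add
// operand and cloning the contraction; it is tried for both operand orders
// of the add.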
  setResultRanges(getResult(), argRanges.front());

  auto vectorTy = cast<VectorType>(source.getType());

  build(builder, result, source, dynamicPos,

ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());

    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();

  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))

LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
          "expected a scalar instead of a 0-d vector as the result type");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");

  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
template <typename IntType>
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));

  if (!extractOp.getSource().getDefiningOp<ExtractOp>())

  if (extractOp.hasDynamicPosition())

  ExtractOp currentOp = extractOp;
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
    if (currentOp.hasDynamicPosition())
    globalPosition.append(extrPos.rbegin(), extrPos.rend());

  extractOp.setOperand(0, currentOp.getSource());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
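// Note (illustrative): chained extracts fold by appending each op's static
// position innermost-first and reversing at the end, so
// extract(extract(v)[a])[b] becomes a single extract(v)[a, b].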
class ExtractFromInsertTransposeChainState {
  ExtractFromInsertTransposeChainState(ExtractOp e);

  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)

    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));

  void updateStateForNextIteration(Value v) {

  LogicalResult handleTransposeOp();

  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
                         extractOp.getStaticPosition().end());
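// Note (illustrative): the sentinels are distinct negative values (-1, -2,
// ...) padding the extract position up to the full vector rank, so a
// transpose permutation can be applied to the whole position; the fold is
// accepted only if the sentinels come back unpermuted in the trailing slots
// (see canFold above).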
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (extractOp.hasDynamicPosition())

  if (!nextTransposeOp)
      nextTransposeOp.getPermutation(), extractOp.getContext()));

ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))

  res = nextInsertOp.getValueToStore();

ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  res = nextInsertOp.getValueToStore();

Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
  if (extractOp.hasDynamicPosition())

  bool nothingToFold = (source == extractOp.getSource());
  if (nothingToFold || !canFold())

  OpBuilder b(extractOp.getContext());
  extractOp.setStaticPosition(
  extractOp.getSourceMutable().assign(source);
  return extractOp.getResult();

Value ExtractFromInsertTransposeChainState::fold() {
  if (extractOp.hasDynamicPosition())

  Value valueToExtractFrom = extractOp.getSource();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);

    if (succeeded(handleInsertOpWithMatchingPos(result)))

    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);

  return tryToFoldExtractOpInPlace(valueToExtractFrom);
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;

  if (isa<BroadcastOp>(op))

  auto shapeCast = dyn_cast<ShapeCastOp>(op);
  VectorType srcType = shapeCast.getSourceVectorType();
  uint64_t srcRank = srcType.getRank();
  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;

  Operation *defOp = extractOp.getSource().getDefiningOp();
  if (extractOp.getType() == input.getType())

  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;

  if (extractRank > inputRank)

  assert(inputType && "input must be a vector type because of previous checks");

      extractType.getShape() != inputShape.take_back(extractRank))

  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;
  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];

  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
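// Note (illustrative): extract(broadcast(x)) can often bypass the
// broadcast: position entries for dims created by the broadcast are dropped
// (or zeroed for stretched unit dims) and the extract is retargeted at the
// narrower input, e.g. extracting [1, 2] from a broadcast of vector<4xf32>
// to vector<2x4xf32> becomes an extract of [2] from the source.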
  if (extractOp.hasDynamicPosition())

  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();

  if (shuffleOp.getResultVectorType().getRank() != 1)

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});

  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())

  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();

  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();

      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())

  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))

  std::reverse(extractedPos.begin(), extractedPos.end());
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);

      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  std::reverse(newStrides.begin(), newStrides.end());

  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
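// Note (illustrative): the shape_cast fold linearizes the extract position
// with row-major strides of the result shape and de-linearizes the flat
// index with strides of the shape_cast source, which is valid because
// shape_cast preserves row-major element order.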
  if (extractOp.hasDynamicPosition())

  auto extractStridedSliceOp =
      extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)

  if (extractStridedSliceOp.hasNonUnitStrides())

  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))

    sliceOffsets.pop_back();

  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();

  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())

  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
  if (extractOp.hasDynamicPosition())

      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;

    bool disjoint = false;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
          (dim < insertRankDiff)
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t offset = extractOffsets[dim];
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);

          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +

      extractOp.getSourceMutable().assign(insertOp.getValueToStore());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();

    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  if (extractOp.hasDynamicPosition())

  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)

  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())

  int64_t rank = vecType.getRank();

  if (extractOp.getType() != vecType.getElementType())
         "unexpected number of indices");

  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  return fromElementsOp.getElements()[flatIndex];
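// Note (illustrative): the flat index is the row-major linearization of the
// extract position, e.g. for vector<2x3xf32> the position [1, 2] maps to
// 1 * 3 + 2 = 5, selecting the sixth from_elements operand.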
template <typename OpType, typename AdaptorType>
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();

  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();

  if (!dynamicPosition.size())

  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (ShapedType::isStatic(staticPosition[i]))

    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
        staticPosition[i] = attr.getInt();

    operands.push_back(position);

  op.setStaticPosition(staticPosition);
  op.getOperation()->setOperands(operands);
  return op.getResult();
  if (!is_contained(staticPos, poisonVal))

  return ub::PoisonAttr::get(context);

  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))

  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);

  if (denseAttr.isSplat()) {
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))

  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())

  if (extractOp.hasDynamicPosition()) {

  copy(extractOp.getStaticPosition(), completePositions.begin());
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
    newAttr = *denseValuesBegin;

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())

  SmallVector<Value> operands = {getSource()};
      getContext(), adaptor.getStaticPosition(), kPoisonIndex))

  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())

  return inplaceFolded;
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
        BroadcastableToResult::Success)

class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
        containsUnknownDims = true;

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        containsUnknownDims = true;

    } else if (!containsUnknownDims) {
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
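// Note (illustrative): extracting from a create_mask can be rewritten as a
// smaller create_mask over the trailing operands; a constant position at or
// past its mask bound makes the whole result all-false, while non-constant
// bounds leave the dim unknown and block the rewrite.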
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());

  if (sourceType.getNumElements() != targetType.getNumElements())

                                            castOp.getSource());

LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  if (extractOp.hasDynamicPosition())

  auto resultType = dyn_cast<VectorType>(extractOp.getType());

  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)

  VectorType inputType = fromElementsOp.getType();

  if (resultType.isScalable() || inputType.isScalable())

  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);

  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);

      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {

  if (operands.empty())

  return llvm::all_of(operands, [&](Value operand) {
    return currentDef == defOp;

  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)

  llvm::append_range(results, fromElementsOp.getElements());

  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();

  if (isa<VectorType>(bcastOp.getSource().getType()))

  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
  Value scalar = bcastOp.getSource();
  results.assign(resultVecType.getNumElements(), scalar);
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {

ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                               ToElementsOp::Adaptor adaptor,
                               SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vecType = cast<VectorType>(adaptor.getSource().getType());
  Type elType = vecType.getElementType();
  inferredReturnTypes.append(vecType.getNumElements(), elType);

  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
  auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
  auto dstType = cast<VectorType>(toElementsOp.getSource().getType());

  int64_t dstRank = dstShape.size();
  int64_t srcRank = srcShape.size();

  auto srcElems = vector::ToElementsOp::create(
      rewriter, toElementsOp.getLoc(), bcastOp.getSource());

  int64_t dstCount = llvm::product_of(dstShape);
  replacements.reserve(dstCount);

  for (int64_t lin = 0; lin < dstCount; ++lin) {
    for (int64_t k = 0; k < srcRank; ++k)
      srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
    replacements.push_back(srcElems.getResult(srcLin));

  rewriter.replaceOp(toElementsOp, replacements);

void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ToElementsOfBroadcast>(context);
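// Note (illustrative): ToElementsOfBroadcast rewrites
// to_elements(broadcast(src)) as to_elements(src) plus result remapping:
// each destination element's multi-index is projected onto the source by
// zeroing the dims the broadcast stretched from size 1.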
  OperandRange fromElemsOperands = fromElementsOp.getElements();
  if (fromElemsOperands.empty())

  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();

  Value toElementsInput = toElementsOp.getSource();
  if (fromElementsOp.getType() == toElementsInput.getType() &&
      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
    return toElementsInput;

  if (llvm::any_of(elements, [](Attribute attr) {
        return !attr || isa<ub::PoisonAttrInterface>(attr);

  auto destVecType = fromElementsOp.getDest().getType();
  auto destEltType = destVecType.getElementType();
  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))

  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {

  if (!llvm::all_equal(fromElementsOp.getElements()))
      fromElementsOp, fromElementsOp.getType(),
      fromElementsOp.getElements().front());
  LogicalResult matchAndRewrite(FromElementsOp fromElements,
    if (fromElements.getType().getNumElements() == 1)

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {
      auto extractOp = element.getDefiningOp<vector::ExtractOp>();
                                           "element not from vector.extract");

      if (insertIndex == 0) {
        source = extractOp.getSource();
      } else if (extractOp.getSource() != source) {
                                           "element from different vector");

      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);

        if (numSuffixElms != numElms) {
              fromElements, "elements do not form a suffix of source");

        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);

      } else if (expectedPosition != position) {
            fromElements, "elements not in ascending order (static order)");

      increment(expectedPosition, source.getType().getShape());

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);

        fromElements, fromElements.getType(), extracted);

  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
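// Note (illustrative): the pattern above recognizes a from_elements whose
// operands are consecutive scalar extracts covering a suffix of a single
// source vector, and replaces the whole construction with one extract of
// that source reshaped to the result type.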
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
  setResultRanges(getResult(), argRanges.front());

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
static llvm::SetVector<int64_t>
  int64_t rankDiff = dstShape.size() - srcShape.size();
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),

Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
    checkShape.push_back(dstShape[i]);
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "

  Location loc = value.getLoc();
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());

  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;

      permutation[i] = nextSrcShapeDim++;

  llvm::append_range(broadcastShape, srcVectorType.getShape());
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);

  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
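// Note (illustrative): createOrFoldBroadcastOp first broadcasts to a shape
// with all broadcast dims leading, then emits one transpose moving them to
// their requested positions, so an arbitrary broadcastedDims set costs at
// most one broadcast plus one transpose.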
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&

  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)

  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    bool foundMismatchingDims = false;

    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;

LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
      getSourceType(), getResultVectorType(), &mismatchingDims);
    return emitOpError("source rank higher than destination rank");
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";

    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();

  VectorType srcType = srcShapeCast.getSourceVectorType();
  VectorType destType = broadcastOp.getResultVectorType();
      srcShapeCast.getResultVectorType().getShape();

  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
  if (!llvm::equal(srcShape.take_back(numTrailingDims),
                   shapecastShape.take_back(numTrailingDims)))

  assert(all_of(srcShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         all_of(shapecastShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         "ill-formed shape_cast");

  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())

  if (!adaptor.getSource())

  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())

  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())

  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))

  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))

struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
                                                     broadcastOp.getResultVectorType(),
                                                     srcBroadcast.getSource());

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder>(context);
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();

  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)

  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)

  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)

  if (maskLength != resultType.getDimSize(0))

  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
      return emitOpError("mask index #") << (idx + 1) << " out of range";

ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();

  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));

  llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
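// Note (illustrative): the inferred shuffle type keeps v1's trailing dims
// and takes its leading dim from the mask length (clamped to 1 for the 0-D
// case), e.g. shuffling two vector<2x4xf32> values with a 3-entry mask
// yields vector<3x4xf32>.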
template <typename T>
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
    return value == expected++;

OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  if (v1Type.getRank() == 0)

  auto mask = getMask();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)

  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)

  if (v1Type.getRank() != 1)

  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;

  auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];

  auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    poisonElement = v1Elements[0];

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;

    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
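// Note (illustrative): constant shuffles fold elementwise: a poison mask
// entry (-1) or a poison input yields a poison element, otherwise the mask
// index selects from v1 when below v1's size and from v2 (shifted by
// v1Size) otherwise.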
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {

  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
                                PatternRewriter &rewriter) const override {
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
    if (mask.size() != 1)
    VectorType resType = VectorType::Builder(v1VectorType).setShape({1});

static Value getScalarSplatSource(Value value) {
  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
  if (isa<VectorType>(broadcast.getSourceType()))

class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getV1());
    if (!splat || getScalarSplatSource(op.getV2()) != splat)

class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
          op, "ShuffleOp can't represent a scalable interleave");
    if (resultType.getRank() != 1)
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
          op, "ShuffleOp types don't match an interleave");

    ArrayRef<int64_t> shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
            "ShuffleOp mask not interleaving");

void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3352void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3354 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest) {
  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, OpFoldResult position) {
  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<int64_t> position) {
  SmallVector<OpFoldResult> posVals;
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult InsertOp::verify() {
  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
    if (srcTy.getRank() == 0)
      return emitOpError(
          "expected a scalar instead of a 0-d vector as the source operand");

  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (constIdx != kPoisonIndex &&
          (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "dest vector dimension";
      }
    }
  }
  return success();
}

// Calculate the linearized position of the continuous chunk of elements to
// insert when it is fully inserted into the destination vector.
static int64_t calculateInsertPosition(VectorType destTy,
                                       ArrayRef<int64_t> positions) {
  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
  assert(positions.size() <= completePositions.size() &&
         "positions size must be less than or equal to destTy rank");
  copy(positions, completePositions.begin());
  return linearize(completePositions, computeStrides(destTy.getShape()));
}
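/// Example for the pattern below (illustrative): inserting a vector that
/// already covers every element of the destination is just a shape-changing
/// broadcast:
///
///   %0 = vector.insert %src, %dest [0] : vector<4xf32> into vector<1x4xf32>
///   ==> %0 = vector.broadcast %src : vector<4xf32> to vector<1x4xf32>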
/// If insertOp is only inserting unit dimensions it can be transformed to a
/// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType =
        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
    return success();
  }
};

/// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getValueToStore());
    if (!splat || getScalarSplatSource(op.getDest()) != splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};
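/// Example for the pattern below (illustrative): a chain of inserts that
/// initializes every element can be replaced by vector.from_elements:
///
///   %0 = vector.insert %a, %poison [0] : i32 into vector<2xi32>
///   %1 = vector.insert %b, %0 [1] : i32 into vector<2xi32>
///   ==> %1 = vector.from_elements %a, %b : vector<2xi32>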
/// Pattern to optimize a chain of insertions into a vector.from_elements op.
/// Applies when every insert uses static positions, the chain initializes all
/// elements of the destination, and intermediate inserts have a single use.
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    VectorType destTy = op.getDestVectorType();
    if (destTy.isScalable())
      return failure();
    // Only consider the last op in a chain of insertions, i.e. one whose
    // result is not itself the dest of another insert.
    for (Operation *user : op.getResult().getUsers())
      if (auto insertOp = dyn_cast<InsertOp>(user))
        if (insertOp.getDest() == op.getResult())
          return failure();

    InsertOp currentOp = op;
    SmallVector<InsertOp> chainInsertOps;
    while (currentOp) {
      // Dynamic positions are not supported.
      if (currentOp.hasDynamicPosition())
        return failure();

      chainInsertOps.push_back(currentOp);
      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
      // Intermediate inserts must have a single use.
      if (currentOp && !currentOp->hasOneUse())
        return failure();
    }

    int64_t vectorSize = destTy.getNumElements();
    int64_t initializedCount = 0;
    SmallVector<bool> initializedDestIdxs(vectorSize, false);
    SmallVector<int64_t> pendingInsertPos;
    SmallVector<int64_t> pendingInsertSize;
    SmallVector<Value> pendingInsertValues;

    for (auto insertOp : chainInsertOps) {
      // The insert op folder handles inserts at poison indices.
      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
        return failure();

      // Calculate the linearized position for inserting elements.
      int64_t insertBeginPosition =
          calculateInsertPosition(destTy, insertOp.getStaticPosition());

      // The value to store may be a vector or a scalar.
      int64_t insertSize = 1;
      if (auto srcVectorType =
              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
        insertSize = srcVectorType.getNumElements();

      assert(insertBeginPosition + insertSize <= vectorSize &&
             "insert would overflow the vector");

      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
                                           insertBeginPosition + insertSize)) {
        if (initializedDestIdxs[index])
          continue;
        initializedDestIdxs[index] = true;
        ++initializedCount;
      }

      // Defer op creation until the pattern is known to succeed.
      pendingInsertPos.push_back(insertBeginPosition);
      pendingInsertSize.push_back(insertSize);
      pendingInsertValues.push_back(insertOp.getValueToStore());

      if (initializedCount == vectorSize)
        break;
    }

    // Some positions were never initialized.
    if (initializedCount != vectorSize)
      return failure();

    SmallVector<Value> elements(vectorSize);
    // Visit the chain from earliest to latest insert, so that later inserts
    // overwrite earlier ones.
    for (auto [insertBeginPosition, insertSize, valueToStore] :
         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
                                 pendingInsertValues))) {
      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());

      if (!srcVectorType) {
        elements[insertBeginPosition] = valueToStore;
        continue;
      }

      SmallVector<Type> elementToInsertTypes(insertSize,
                                             srcVectorType.getElementType());
      // Get all elements from the vector in row-major order.
      auto elementsToInsert = vector::ToElementsOp::create(
          rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
        elements[insertBeginPosition + linearIdx] =
            elementsToInsert.getResult(linearIdx);
      }
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
    return success();
  }
};
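/// Example for the folder below (illustrative): inserting a constant into a
/// constant destination folds to a single dense constant:
///
///   %v = arith.constant dense<[1, 2]> : vector<2xi32>
///   %d = arith.constant dense<0> : vector<2x2xi32>
///   %0 = vector.insert %v, %d [1] : vector<2xi32> into vector<2x2xi32>
///   // folds to dense<[[0, 0], [1, 2]]> : vector<2x2xi32>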
static OpFoldResult
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
                                  Attribute dstAttr,
                                  int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst || !srcAttr)
    return {};

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Do not create a large constant unless the original dest constant has a
  // single use.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  // Calculate the linearized position for inserting elements.
  int64_t insertBeginPosition =
      calculateInsertPosition(destTy, insertOp.getStaticPosition());
  SmallVector<Attribute> insertedValues;

  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(value);
  } else {
    insertedValues.push_back(srcAttr);
  }

  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);
  return DenseElementsAttr::get(destTy, allValues);
}

/// Replace the dest operand of an insert with the root dest of a chain of
/// inserts that all target the same position (the later insert wins).
static Value foldInsertUseChain(InsertOp insertOp) {
  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
  if (!destInsert)
    return {};

  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
    return {};

  insertOp.setOperand(1, destInsert.getDest());
  return insertOp.getResult();
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertChainFullyInitialized>(context);
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // An insert with no indices and matching types is the stored value itself.
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();

  OpFoldResult inplaceFolded = foldInsertUseChain(*this);
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }

  return inplaceFolded;
}
//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}

// TODO: Should be moved to Tablegen ConfinedAttr attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}

// Returns success if all integers in `arrayAttr` are confined to [min, max).
// If `halfOpen` is false the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ")
             << attrName << " to be confined to [" << min << ", " << upper
             << ")";
  }
  return success();
}
// Returns success if each integer in `arrayAttr` is confined to
// [min, shape[i]) for the corresponding dimension i (or [min, shape[i]] when
// `halfOpen` is false).
template <typename OpType>
static LogicalResult isIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr, ArrayRef<int64_t> shape,
    StringRef attrName, bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}

// Returns success if, for each index i, the sum arrayAttr1[i] + arrayAttr2[i]
// is confined to [min, shape[i]) (or [min, shape[i]] when `halfOpen` is
// false).
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }
  return success();
}
/// Pattern to rewrite an InsertStridedSliceOp(splat-like(v), splat-like(v))
/// as its dest.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto dst = insertStridedSliceOp.getDest();
    auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
    if (!splat || getScalarSplatSource(dst) != splat)
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, dst);
    return success();
  }
};
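/// Example for the pattern below (illustrative): re-inserting a slice exactly
/// where it was extracted from is a no-op:
///
///   %e = vector.extract_strided_slice %d
///       {offsets = [2], sizes = [4], strides = [1]} ...
///   %0 = vector.insert_strided_slice %e, %d
///       {offsets = [2], strides = [1]} ...   // folds to %d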
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getValueToStore()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check that the slice is inserted exactly where it was extracted from.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() !=
            insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
/// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
/// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if the dest operand is not defined by a compatible constant.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();
    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Do not create a large constant unless the dest has a single use.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    TypedValue<VectorType> sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Support poison.
    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Overwrite the dest elements covered by the slice by enumerating all
    // slice positions in lexicographic order and linearizing them.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getValueToStore();
  return {};
}
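//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

/// Custom assembly handled below prints the operands and then the two operand
/// types; the accumulator is optional, e.g. (illustrative):
///
///   %0 = vector.outerproduct %lhs, %rhs, %acc
///       : vector<4xf32>, vector<8xf32>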
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}

void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}

ParseResult OuterProductOp::parse(OpAsmParser &parser,
                                  OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand (AXPY).
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what is currently supported for scalable
      // vectors; it could be relaxed given a use case.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes; requires the operation to be vectorized.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}
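//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//

/// Example for the inference below (illustrative): with offsets = [0, 0],
/// sizes = [2, 2], strides = [1, 1] on vector<4x8x16xf32>, the sliced dims
/// take their size from `sizes` and the rest keep the source shape, giving
/// vector<2x2x16xf32>.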
// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for remaining dims.
// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}

void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape,
                                                offName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape,
                                                sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }
  return success();
}
// When the source of ExtractStridedSlice comes from a chain of
// InsertStridedSlice ops, try to replace it with the inserted value, adjusting
// the offsets when the extracted chunk is fully contained in an insertion.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the rank of extract is greater than the rank of insert, we are
    // likely extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the start of the extract offset is in the inserted interval.
      if (start <= offset && offset < end) {
        // If the extract interval overlaps but is not fully included we have
        // a partial overlap.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
      } else {
        disjoint = true;
        break;
      }
    }
    // The extracted chunk is a subset of the inserted chunk: fold.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getValueToStore());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunks are disjoint keep looking up the insert chain; a partial
    // overlap cannot be folded.
    if (!disjoint)
      return failure();
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return failure();
}
// Fold an ExtractStridedSliceOp of a non-splat dense constant into a smaller
// dense constant by gathering the selected elements.
static OpFoldResult
foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
                                        Attribute foldInput) {
  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
  if (!dense)
    return {};
  // TODO: Handle non-unit strides when they become available.
  if (op.hasNonUnitStrides())
    return {};

  VectorType sourceVecTy = op.getSourceVectorType();
  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceVecTy.getShape());
  VectorType sliceVecTy = op.getType();
  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
  int64_t rank = sliceVecTy.getRank();

  // Expand offsets to the slice rank (unspecified trailing offsets are 0).
  SmallVector<int64_t, 4> offsets(rank, 0);
  copy(getI64SubArray(op.getOffsets()), offsets.begin());

  // Walk all slice positions in order, mapping each into the source constant.
  const auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(sliceVecTy.getNumElements());
  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
  do {
    int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
  } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
             sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  return DenseElementsAttr::get(sliceVecTy, sliceValues);
}

OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getSource();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();
  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
}

void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getOffsets(), results);
}
// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) into a
// CreateMaskOp over the sliced region.
class StridedSliceCreateMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    Location loc = extractStridedSliceOp.getLoc();
    // Return if the source is not defined by a CreateMaskOp.
    auto createMaskOp =
        extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
    if (!createMaskOp)
      return failure();
    // Return if the op has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather mask dimension sizes and strided slice offsets/sizes.
    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the slice of the vector mask region.
    SmallVector<Value> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    // No need to clamp: create_mask has clamping semantics, i.e. negative or
    // oversized mask dim sizes are allowed.
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      IntegerAttr offsetAttr =
          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
      Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
      Value sliceMaskDimSize =
          arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    llvm::append_range(
        sliceMaskDimSizes,
        llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
    // Replace with a CreateMaskOp over the sliced mask region.
    rewriter.replaceOpWithNewOp<CreateMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the source is not defined by a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if the op has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes and slice offsets/sizes.
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the slice of the vector mask region.
    SmallVector<int64_t> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any of the sliced dimensions is zero, the whole mask is all-false.
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
    // Replace with a ConstantMaskOp over the sliced mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the source dims (i.e. the trailing dims of the destination)
    // are untouched by the slice; if so a plain broadcast suffices.
    bool needsSlice = false;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != 1 &&
          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        needsSlice = true;
        break;
      }
    }
    Value source = broadcast.getSource();
    if (needsSlice) {
      SmallVector<int64_t> offsets =
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
      SmallVector<int64_t> sizes =
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
      for (unsigned i = 0; i < srcRank; i++) {
        if (srcVecType.getDimSize(i) == 1) {
          // The dimension is stretched by the broadcast; the slice must stay
          // within the unit source dimension.
          offsets[i] = 0;
          sizes[i] = 1;
        }
      }
      source = ExtractStridedSliceOp::create(
          rewriter, op->getLoc(), source, offsets, sizes,
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};

/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(op.getSource());
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
    return success();
  }
};
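/// Example for the pattern below (illustrative): a contiguous slice with unit
/// outer sizes reduces to vector.extract (+ shape_cast):
///
///   %1 = vector.extract_strided_slice %arg0
///       {offsets = [0, 0], sizes = [1, 8], strides = [1, 1]}
///       : vector<8x8xf32> to vector<1x8xf32>
///   ==> %0 = vector.extract %arg0[0] : vector<8xf32> from vector<8x8xf32>
///       %1 = vector.shape_cast %0 : vector<8xf32> to vector<1x8xf32>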
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build: walk the
    // dimensions from innermost out and stop at the first non-full-size dim.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions; the
    // shape_cast created below would trigger bad fallback patterns.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
                                              extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
                                                     extract);
    return success();
  }
};

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
              StridedSliceBroadcast, StridedSliceSplat,
              ContiguousExtractStridedSliceToExtract>(context);
}
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

/// Builder that sets padding to a poison value if not provided.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}

/// Builder that takes a plain AffineMap and optional in_bounds.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(SmallVector<bool>(
                                vectorType.getRank(), false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, *padding,
        permutationMapAttr, inBoundsAttr);
}

/// Builder that infers a minor-identity permutation map.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(SmallVector<bool>(
                                vectorType.getRank(), false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (inBounds.size() != permutationMap.getNumResults())
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds attribute if all dims are out-of-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}

VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

  // 0-D masks are not supported yet.
  if (maskShape.empty())
    maskShape.push_back(1);

  SmallVector<bool> scalableDims =
      applyPermutationMap(invPermMap, vecType.getScalableDims());

  return VectorType::get(maskShape, i1Type, scalableDims);
}
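/// Assembly form parsed below, e.g. (illustrative):
///
///   %v = vector.transfer_read %A[%i, %j], %pad
///       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///       : memref<?x?xf32>, vector<8x4xf32>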
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() < vectorType.getRank())
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // The mask type is computed from the vector type and the permutation map
    // to keep the type signature small.
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type: padding must be that same vector type.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

/// Returns the mask type expected by this operation.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

VectorType TransferReadOp::getVectorType() {
  return cast<VectorType>(getVector().getType());
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}

template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  if (op.getTransferRank() == 0)
    return failure();

  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Idxs of non-broadcast dims - used when analysing broadcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }
    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are "in bounds", then
  // all broadcast dims are "in bounds" as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build a bool array attribute.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}

template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();

  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();

  op.getMaskMutable().clear();
  return success();
}
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldTransferFullMask(*this)))
    return getResult();
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}

/// Inverse of `map`, which must be a projected permutation. Unused dims
/// (those not appearing in any result) are mapped to constant-zero results so
/// they remain recognizable after composition.
static AffineMap inverseWithUnusedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected a projected permutation map");
  SmallVector<AffineExpr> exprs(map.getNumInputs(),
                                getAffineConstantExpr(0, map.getContext()));
  for (auto [idx, result] : llvm::enumerate(map.getResults())) {
    int64_t pos = cast<AffineDimExpr>(result).getPosition();
    exprs[pos] = getAffineDimExpr(idx, map.getContext());
  }
  return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs,
                        map.getContext());
}
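/// Example for the pattern below (illustrative): reading back a vector that
/// was just written to the same tensor indices can skip the tensor
/// round-trip, at most requiring a broadcast and a transpose of the written
/// vector.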
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if we would need an alias analysis.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail if we would need a bounds analysis.
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    // Fail if the write accesses a different chunk than the read.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();

    // Compose the read map with the inverse of the write map to get the
    // mapping from the written vector to the read vector.
    AffineMap readMap = readOp.getPermutationMap();
    AffineMap writeMap = defWrite.getPermutationMap();
    AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
    AffineMap composedMap = readMap.compose(invWriteMap);

    // Dims that the read broadcasts show up as constant-zero results; they
    // lead in the intermediate (broadcasted) vector.
    SmallVector<int64_t> broadcastedDims;
    for (auto [idx, expr] : llvm::enumerate(composedMap.getResults()))
      if (isa<AffineConstantExpr>(expr))
        broadcastedDims.push_back(idx);

    int64_t numBroadcastedDims = broadcastedDims.size();
    auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
    invPerm.resize(composedMap.getNumResults());
    for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
      if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
        int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
        invPerm[effectiveDim] = idx;
      }
    }

    // Build the intermediate type so that transposing it by `invPerm` yields
    // the read vector type.
    VectorType readVecTy = readOp.getVectorType();
    SmallVector<int64_t> bcastShape(readVecTy.getRank());
    SmallVector<bool> bcastScalableDims(readVecTy.getRank());
    for (auto [readDim, bcastDim] : llvm::enumerate(invPerm)) {
      bcastShape[bcastDim] = readVecTy.getDimSize(readDim);
      bcastScalableDims[bcastDim] = readVecTy.getScalableDims()[readDim];
    }
    auto broadcastedVecTy = VectorType::get(
        bcastShape, readVecTy.getElementType(), bcastScalableDims);

    Value vec = defWrite.getVector();
    Location loc = readOp.getLoc();
    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, invPerm);
    return success();
  }
};

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  // NOTE: assumes the shared bubbleDownCastsPassthroughOpImpl helper defined
  // elsewhere in this file; the memory ops below delegate the same way.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getBaseMutable());
}
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//

/// Builder with a result type inferred from a ranked-tensor destination.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr, Value mask,
                            ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// Builder with an empty mask.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// Builder taking a plain AffineMap and optional in_bounds.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// Builder that infers a minor-identity permutation map.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() < vectorType.getRank())
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}

LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

/// Returns the mask type expected by this operation.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

Value TransferWriteOp::getVector() { return getOperand(0); }
VectorType TransferWriteOp::getVectorType() {
  return cast<VectorType>(getValueToStore().getType());
}
/// Fold writes that overwrite an entire tensor with a vector that was fully
/// read from a tensor of identical static shape: the result is the read's
/// source tensor.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and Tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // All indices must be 0.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getBase());
  return success();
}

static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getBase() == write.getBase() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}
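/// Example for foldWAR below (illustrative): writing back the exact value
/// just read from the same tensor at the same indices is a no-op:
///
///   %v = vector.transfer_read %t[%c0], %pad : tensor<4xf32>, vector<4xf32>
///   %w = vector.transfer_write %v, %t[%c0] : vector<4xf32>, tensor<4xf32>
///   // %w folds to %t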
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getBase());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  if (succeeded(foldTransferFullMask(*this)))
    return success();
  return memref::foldMemRefCast(*this);
}

std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
/// Fold write-after-write on tensors: a write that is entirely overwritten by
/// a later write to the same location can write directly to the earlier
/// write's destination.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely
      // look at the one before it.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
/// Swap a tensor::ExtractSliceOp in front of a vector::TransferWriteOp that
/// feeds a tensor::InsertSliceOp, when the write fully covers the slice.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
    // rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if tensor::TransferWriteOp has non-zero offset.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TranferWriteOp has non-zero offset");
    }

    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp ranks differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(),
                         extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the vector::TransferWriteOp may not write the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the tensor::ExtractSliceOp in front of the TransferWriteOp.
    SmallVector<bool> newInBounds(vectorShape.size(), false);
    auto newExtractOp = tensor::ExtractSliceOp::create(
        rewriter, extractOp.getLoc(), insertOp.getSourceType(),
        insertOp.getDest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    auto newTransferWriteOp = TransferWriteOp::create(
        rewriter, transferOp.getLoc(), transferOp.getVector(),
        newExtractOp.getResult(), transferOp.getIndices(),
        transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(newInBounds));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  // Same assumed helper as in TransferReadOp::bubbleDownCasts.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getBaseMutable());
}

//===----------------------------------------------------------------------===//
// LoadOp / StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If rank == 0 or the vector holds a single element, the minor identity
  // requirement is trivially satisfied.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}
LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < resVecTy.getRank())
    return emitOpError(
        "destination memref has lower rank than the result vector");

  // Checks for memrefs of vectors.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

OpFoldResult LoadOp::fold(FoldAdaptor) {
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  // Same assumed helper as in TransferReadOp::bubbleDownCasts.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getBaseMutable());
}
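/// Example accepted by the verifier below (illustrative): a memref whose
/// element type is itself the stored vector type matches directly:
///
///   vector.store %v, %base[%i] : memref<200xvector<8xf32>>, vector<8xf32>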
LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < valueVecTy.getRank())
    return emitOpError(
        "source memref has lower rank than the vector to store");

  // Checks for memrefs of vectors.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  // Same assumed helper as in TransferReadOp::bubbleDownCasts.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getBaseMutable());
}
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// Fold masked loads with an all-true or all-false mask into an unmasked load
/// or the pass-through value, respectively.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}

OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // Same assumed helper as in TransferReadOp::bubbleDownCasts.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getBaseMutable());
}
//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}

/// Fold masked stores with an all-true or all-false mask into an unmasked
/// store or a no-op, respectively.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // Same assumed helper as in TransferReadOp::bubbleDownCasts.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getBaseMutable());
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes; requires the operation to be vectorized.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// Returns success if the indices in `indexVec` are 0, 1, 2, ...
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  if (indexVec.getDefiningOp<StepOp>())
    return success();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
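/// Example for the patterns below (illustrative): gathering with indices
/// [0, 1, 2, ...] is a contiguous masked load:
///
///   %i = vector.step : vector<16xindex>
///   %g = vector.gather %base[%c0][%i], %mask, %pass
///       : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
///         into vector<16xf32>
///   ==> %g = vector.maskedload %base[%c0], %mask, %pass
///       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>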
/// Fold gathers with an all-false mask to their pass-through operand.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

/// Fold gathers with consecutive indices [0, 1, 2, ...] into contiguous
/// masked loads.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getOffsets(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
  // Same assumed helper as in TransferReadOp::bubbleDownCasts.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getBaseMutable());
}
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (valueVType.getElementType() != baseType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    ShapedType baseType = scatter.getBaseType();
    bool isMemRef = isa<MemRefType>(baseType);
    if (!isMemRef && !isa<RankedTensorType>(baseType))
      return failure();
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      // An all-false scatter writes nothing: erase the memref form, or
      // forward the base for the tensor form.
      if (isMemRef)
        rewriter.eraseOp(scatter);
      else
        rewriter.replaceOp(scatter, scatter.getBase());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};
} // namespace
namespace {
/// Folds a scatter with contiguous zero-based indices into a masked store.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");
    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
    return success();
  }
};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                              getOperation()->getResults());
}
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                              getOperation()->getResults());
}
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
                                              getOperation()->getResults());
}
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
LogicalResult ShapeCastOp::verify() {

  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that the element type is preserved.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitOpError("has different source and result element types");

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}
/// Return true if `transpose` does not permute a pair of non-unit dims.
/// In other words, the flattened source and result vectors are numerically
/// identical and the transpose is effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current)
        return false;
      current = p;
    }
  }
  return true;
}
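// Example (illustrative): for a source of type vector<1x4x1x8xf32> the
// permutation [2, 0, 1, 3] only moves the unit dims (0 and 2) around the
// non-unit dims 1 and 3, which keep their relative order, so it is order
// preserving; [3, 0, 1, 2] swaps the two non-unit dims and is not.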
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x), if order preserving.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // shape_cast(broadcast(x)) -> x, if the types match.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> reshaped constant.
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  // shape_cast(poison) -> poison.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());

  return {};
}
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
///   vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
namespace {
/// Folds shape_cast(create_mask/constant_mask) that merely drops trailing
/// unit dims into a new mask op of the trimmed type.
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Every mask dim size to drop must be a constant 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Every mask dim size to drop must be 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
} // namespace
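// Example (illustrative): the pattern above rewrites
//
//   %1 = vector.create_mask %dim, %c1, %c1 : vector<[16]x1x1xi1>
//   %2 = vector.shape_cast %1 : vector<[16]x1x1xi1> to vector<[16]xi1>
//
// into
//
//   %2 = vector.create_mask %dim : vector<[16]xi1>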
namespace {
/// Rewrites Y = ShapeCast(Broadcast(X)) as either Y = ShapeCast(X) or
/// Y = Broadcast(X), preferring the shape cast when both apply.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;

    // If the broadcast source has as many elements as the shape_cast result,
    // absorb the broadcast: shape_cast(broadcast(x)) -> shape_cast(x).
    if (srcVectorType) {
      if (srcVectorType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    // Otherwise (including scalar sources), broadcast directly to the result.
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    if (srcIsScalar ||
        vector::isBroadcastableTo(broadcastOp.getSourceType(),
                                  dstVectorType) ==
            BroadcastableToResult::Success) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, dstVectorType, broadcastOp.getSource());
      return success();
    }
    return failure();
  }
};
} // namespace
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results
      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
          context);
}
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        unsigned dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower-width pattern to fill the destination width.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(),
                                        IntegerAttr::get(dstElemType, intBits));
        }
      }
    }
  }

  return {};
}
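// Example (illustrative): folding a splat constant through a widening bitcast
// duplicates the narrow bit pattern, e.g. a vector<8xi8> splat of 0x01
// bitcast to vector<2xi32> folds to a splat of 0x01010101.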
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}
LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity-like transposes: same type and order preserving on
  // the non-unit dims.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return {};
}
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify the transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}
std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
namespace {
/// Composes two consecutive transpose ops into a single transpose.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};
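// Example (illustrative): two back-to-back transposes with permutations
// [1, 0, 2] (inner) and [2, 1, 0] (outer) compose into the single
// permutation [2, 0, 1], so
//
//   %1 = vector.transpose %0, [1, 0, 2] : ...
//   %2 = vector.transpose %1, [2, 1, 0] : ...
//
// folds to one vector.transpose of %0 with permutation [2, 0, 1].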
/// Folds transpose(splat x) into splat x of the result type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(transposeOp.getVector());
    if (!splat)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transposeOp, transposeOp.getResultVectorType(), splat);
    return success();
  }
};
/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Apply the transpose permutation to the mask operands or sizes.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};
/// Folds transpose(shape_cast) into a new shape_cast when the transpose is
/// order preserving.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};
/// Rewrites transpose(from_elements) into a new from_elements with the
/// element list reordered according to the permutation.
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElementsOp =
        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
    if (!fromElementsOp)
      return failure();

    VectorType srcTy = fromElementsOp.getDest().getType();
    VectorType dstTy = transposeOp.getType();

    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    int64_t rank = srcTy.getRank();

    // Invert the permutation to map destination dims back to source dims.
    SmallVector<int64_t> inversePerm(rank, 0);
    for (int64_t i = 0; i < rank; ++i)
      inversePerm[permutation[i]] = i;

    ArrayRef<int64_t> srcShape = srcTy.getShape();
    ArrayRef<int64_t> dstShape = dstTy.getShape();
    SmallVector<int64_t> srcIdx(rank, 0);
    SmallVector<int64_t> dstIdx(rank, 0);
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);

    auto elementsOld = fromElementsOp.getElements();
    SmallVector<Value> elementsNew;
    int64_t dstNumElements = dstTy.getNumElements();
    elementsNew.reserve(dstNumElements);

    // Walk the destination in linear order; for each destination position,
    // compute the corresponding source position and pick that element.
    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
      dstIdx = delinearize(linearIdx, dstStrides);
      for (int64_t j = 0; j < rank; ++j)
        srcIdx[j] = dstIdx[inversePerm[j]];

      int64_t srcLin = linearize(srcIdx, srcStrides);
      elementsNew.push_back(elementsOld[srcLin]);
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(transposeOp, dstTy,
                                                        elementsNew);
    return success();
  }
};
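// Example (illustrative): transposing a 2x2 from_elements just permutes the
// element list:
//
//   %v = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
//   %t = vector.transpose %v, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
//
// becomes
//
//   %t = vector.from_elements %a, %c, %b, %d : vector<2x2xf32>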
/// Folds transpose(broadcast(x)) into broadcast(x) when the permutation only
/// moves dims within the unit-dim groups introduced by the broadcast.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // A scalar broadcast can always fold directly to the output type.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    // Check that the permutation stays within each group of broadcast dims.
    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeFromElements,
              FoldTransposeBroadcast>(context);
}
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}
LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that the array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all are (the mask region
  // is the conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> bounds = getMaskDimSizes();
  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();

  auto createBoolSplat = [&](bool x) {
    return SplatElementsAttr::get(getVectorType(),
                                  BoolAttr::get(getContext(), x));
  };

  // Check the corner case of 0-D vectors first.
  if (vectorSizes.empty()) {
    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
    return createBoolSplat(bounds[0] == 1);
  }
  // Fold if the mask is all-true or all-false.
  if (bounds == vectorSizes)
    return createBoolSplat(true);
  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
    return createBoolSplat(false);
  return OpFoldResult();
}
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}
LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
namespace {
/// Folds a create_mask whose operands are all provably constant (or, for
/// scalable dims, known vscale multiples) into a constant_mask.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank-zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value: for a scalable dim only 0 (all-false) works.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple: must cover the whole scalable dim.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp the values to the constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dims is zero the mask is all-false.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
} // namespace
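// Example (illustrative): with all-constant operands the folder above turns
//
//   %m = vector.create_mask %c2, %c3 : vector<4x4xi1>
//
// into
//
//   %m = vector.constant_mask [2, 3] : vector<4x4xi1>
//
// clamping operands to the vector bounds and collapsing any zero dim to the
// all-false mask.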
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();
  auto &builder = parser.getBuilder();

  // Parse the mask operand and the optional passthru operand.
  OpAsmParser::UnresolvedOperand mask, passthru;
  if (parser.parseOperand(mask))
    return failure();
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse the op region and make sure it is properly terminated.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}))
    return failure();
  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse the optional attribute list and all the types.
  Type maskType;
  SmallVector<Type> resultTypes;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(maskType) ||
      parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();
  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }
  return success();
}
void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print the single masked operation and skip the terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.getBlocks().empty() || region.getBlocks().front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.getBlocks().front();
  if (isa<vector::YieldOp>(block.back()))
    return;

  // 3. Otherwise, if the block does not hold exactly one masked operation,
  // create a default terminator (verification will flag the extra ops).
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // Create a terminator that yields the results of the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Folds empty `vector.mask` with no passthru operand.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();

  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  if (terminator.getNumOperands() == 0) {
    // `vector.mask` has no results; it can simply be removed.
    return success();
  }

  // Replace all the results with the yielded values.
  llvm::append_range(results, terminator.getOperands());
  return success();
}

LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();

  MaskFormat maskFormat = getMaskFormat(getMask());
  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move the maskable operation outside of the `vector.mask` region.
  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
namespace {
/// Canonicalizes empty `vector.mask` ops that can't be handled in
/// `MaskOp::fold` as they require creating new operations.
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    if (!maskOp.isEmpty())
      return failure();

    if (!maskOp.hasPassthru())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
        terminator.getOperand(0), maskOp.getPassthru());

    return success();
  }
};
} // namespace

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRanges) {
  auto resultType = cast<VectorType>(getType());
  if (resultType.isScalable()) {
    return;
  }
  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(
      resultType.getElementType());
  APInt zero(bitwidth, 0);
  APInt high(bitwidth, resultType.getDimSize(0) - 1);
  ConstantIntRanges result = {zero, high, zero, high};
  setResultRanges(getResult(), result);
}
namespace {
/// Folds `arith.cmpi` uses of `vector.step` that are provably uniform into a
/// constant boolean splat (step values are always in [0, numElements)).
struct StepCompareFolder : public OpRewritePattern<StepOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    const int64_t stepSize = stepOp.getResult().getType().getNumElements();

    for (OpOperand &use : stepOp.getResult().getUses()) {
      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
      if (!cmpiOp)
        continue;

      // arith.cmpi canonicalization moves constants to the rhs, so only
      // handle `cmpi pred, step, constant`.
      const unsigned stepOperandNumber = use.getOperandNumber();
      if (stepOperandNumber != 0)
        continue;

      // Extract the splat integer value of the other operand, if any. (The
      // exact helper is elided in this listing; any utility returning the
      // splat integer value of a constant vector works.)
      unsigned constOperandNumber = 1;
      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
      std::optional<int64_t> maybeConstValue =
          [&]() -> std::optional<int64_t> {
        DenseIntElementsAttr splatAttr;
        if (!matchPattern(otherOperand, m_Constant(&splatAttr)) ||
            !splatAttr.isSplat())
          return std::nullopt;
        return splatAttr.getSplatValue<APInt>().getSExtValue();
      }();
      if (!maybeConstValue.has_value())
        continue;

      int64_t constValue = maybeConstValue.value();
      arith::CmpIPredicate pred = cmpiOp.getPredicate();

      auto maybeSplat = [&]() -> std::optional<bool> {
        // step values are in [0, stepSize).
        if ((pred == arith::CmpIPredicate::ult ||
             pred == arith::CmpIPredicate::uge) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ult;

        if ((pred == arith::CmpIPredicate::ule ||
             pred == arith::CmpIPredicate::ugt) &&
            stepSize - 1 <= constValue) {
          return pred == arith::CmpIPredicate::ule;
        }

        if ((pred == arith::CmpIPredicate::eq ||
             pred == arith::CmpIPredicate::ne) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ne;

        return std::nullopt;
      }();

      if (!maybeSplat.has_value())
        continue;

      rewriter.setInsertionPointAfter(cmpiOp);

      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
      if (!type)
        continue;

      Value splat = mlir::arith::ConstantOp::create(
          rewriter, cmpiOp.getLoc(),
          DenseElementsAttr::get(type, maybeSplat.value()));
      rewriter.replaceOp(cmpiOp, splat);
      return success();
    }

    return failure();
  }
};
} // namespace

void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<StepCompareFolder>(context);
}
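// Example (illustrative): %step = vector.step : vector<4xindex> produces
// [0, 1, 2, 3], so a comparison such as
//
//   %c = arith.cmpi ult, %step, %c8 : vector<4xindex>
//
// is uniformly true and folds to a dense<true> splat; with `uge` it folds
// to dense<false>.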
/// Builds the mask region of a vector.mask op by moving `maskableOp` inside
/// it and yielding its results.
static void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}

/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid; otherwise, returns
/// the maskable operation itself.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return MaskOp::create(builder, maskableOp->getLoc(),
                          maskableOp->getResultTypes(), mask, passthru,
                          maskableOp, createMaskOpRegion);
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, maskableOp,
                        createMaskOpRegion);
}

/// Creates a vector select that picks values from `newValue` or `passthru`
/// for each result vector lane based on `mask`.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return arith::SelectOp::create(builder, newValue.getLoc(),
                                 newValue.getType(), mask, newValue, passthru);
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble down, in place, a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
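A short sketch under the signature above, assuming element types valid for the combining kind; the wrapper name accumulate is hypothetical:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;

// Hypothetical helper: combine `v` into the accumulator `acc` with ADD.
// For integer elements this emits arith.addi, for floats arith.addf.
static Value accumulate(OpBuilder &b, Location loc, Value v, Value acc) {
  return vector::makeArithReduction(b, loc, vector::CombiningKind::ADD, v, acc);
}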
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
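A sketch of the intended use, assuming maskableOp implements the maskable-op interface and mask has the matching mask type; the wrapper name and the null-mask behavior noted in the comment are assumptions here:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;

// Hypothetical wrapper: guard a maskable op (e.g. vector.reduction) with a
// vector.mask region. A null `mask` is assumed to return the op unmasked.
static Operation *maskIfNeeded(OpBuilder &b, Operation *maskableOp,
                               Value mask) {
  return vector::maskOperation(b, maskableOp, mask);
}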
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
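A worked sketch with assumed types: for memref<8x16xf32> and vector<16xf32> the minor identity map projects onto the innermost dimension; the function name is illustrative:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

// Hypothetical example: the 1-D vector maps onto the innermost memref dim.
static AffineMap exampleMinorIdentity(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  auto memrefTy = MemRefType::get({8, 16}, f32);
  auto vecTy = VectorType::get({16}, f32);
  return vector::getTransferMinorIdentityMap(memrefTy, vecTy); // (d0, d1) -> (d1)
}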
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
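A sketch of the inference with an assumed transposing permutation map; the wrapper name is hypothetical:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;

// For vector<4x8xf32> and permutation map (d0, d1) -> (d1, d0), the inferred
// mask type is vector<8x4xi1>: the mask lives in the permuted iteration space.
static VectorType exampleMaskType(VectorType vecTy, AffineMap permMap) {
  return vector::inferTransferOpMaskType(vecTy, permMap);
}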
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the accessed tensor/memref to be the same.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operations to access the same tensor/memref.
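A hedged sketch of how a transform might consult this query before forwarding across a write/read pair; the function name provablyDisjoint is illustrative:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;

// Hypothetical check: prove a prior write cannot alias a later read. With
// testDynamicValueUsingBounds=true, dynamic indices are also compared through
// the value-bounds machinery rather than only constant folding.
static bool provablyDisjoint(vector::TransferWriteOp w,
                             vector::TransferReadOp r) {
  return vector::isDisjointTransferSet(
      cast<VectorTransferOpInterface>(w.getOperation()),
      cast<VectorTransferOpInterface>(r.getOperation()),
      /*testDynamicValueUsingBounds=*/true);
}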
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-writes the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-terminated region of a vector.mask op, with maskableOp as the masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector.broadcast op.
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
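A small sketch combining matchPattern with the m_Constant matcher listed further below; the helper name getConstInt is hypothetical:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include <optional>
using namespace mlir;

// Hypothetical helper: extract a constant integer by binding an IntegerAttr.
static std::optional<int64_t> getConstInt(Value v) {
  IntegerAttr attr;
  if (matchPattern(v, m_Constant(&attr)))
    return attr.getInt();
  return std::nullopt;
}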
const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config, bool *changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
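A sketch of a typical use, assuming the OpFoldResult came from mixed static/dynamic operands; the predicate name isKnownZero is made up:

#include "mlir/Dialect/Utils/StaticValueUtils.h"
using namespace mlir;

// Hypothetical predicate: true when `ofr` is statically known to be zero,
// whether it holds an IntegerAttr or a Value defined by a constant op.
static bool isKnownZero(OpFoldResult ofr) {
  std::optional<int64_t> cst = getConstantIntValue(ofr);
  return cst && *cst == 0;
}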
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
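A worked sketch of computeStrides and delinearize together, with illustrative numbers and a hypothetical function name:

#include "mlir/Dialect/Utils/IndexingUtils.h"
using namespace mlir;

// For sizes {2, 3, 4}, computeStrides returns the row-major suffix products
// {12, 4, 1}. Delinearizing 23 against them yields {1, 2, 3}, since
// 1*12 + 2*4 + 3*1 == 23.
static SmallVector<int64_t> exampleDelinearize() {
  SmallVector<int64_t> strides = computeStrides({2, 3, 4});
  return delinearize(/*linearIndex=*/23, strides);
}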
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
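A sketch of the dispatch contract, wrapped under a hypothetical name:

#include "mlir/Dialect/Utils/StaticValueUtils.h"
using namespace mlir;

// Constant OpFoldResults land in `staticVec`; Value OpFoldResults land in
// `dynamicVec`, with ShapedType::kDynamic recorded at the corresponding
// position of `staticVec` as the placeholder.
static void splitMixedSizes(ArrayRef<OpFoldResult> mixed,
                            SmallVectorImpl<Value> &dynamicVec,
                            SmallVectorImpl<int64_t> &staticVec) {
  dispatchIndexOpFoldResults(mixed, dynamicVec, staticVec);
}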
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
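A sketch of its behavior; the wrapper name materializeIndex is made up:

#include "mlir/Dialect/Arith/Utils/Utils.h"
using namespace mlir;

// If `ofr` holds an IntegerAttr, an arith.constant of index type is created
// at the insertion point; if it already holds a Value, that Value is returned.
static Value materializeIndex(OpBuilder &b, Location loc, OpFoldResult ofr) {
  return getValueOrCreateConstantIndexOp(b, loc, ofr);
}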
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
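A worked sketch pairing linearize with computeStrides (see the delinearize example above); the basis passed here is the stride vector, so the result is a dot product. The function name is illustrative:

#include "mlir/Dialect/Utils/IndexingUtils.h"
using namespace mlir;

// linearize({1, 2, 3}, {12, 4, 1}) == 1*12 + 2*4 + 3*1 == 23, undoing the
// delinearization of 23 for sizes {2, 3, 4}.
static int64_t exampleLinearize() {
  return linearize({1, 2, 3}, computeStrides({2, 3, 4}));
}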
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using classes from the detail namespace.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
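A worked sketch round-tripping applyPermutation and invertPermutationVector with illustrative values and a hypothetical function name:

#include "mlir/Dialect/Utils/IndexingUtils.h"
using namespace mlir;

// applyPermutation picks input[permutation[i]] for each i, so permutation
// {2, 0, 1} maps {10, 20, 30} to {30, 10, 20}; its inverse is {1, 2, 0}, and
// applying the inverse restores the original order.
static SmallVector<int64_t> examplePermutationRoundTrip() {
  SmallVector<int64_t> perm = {2, 0, 1};
  SmallVector<int64_t> permuted = applyPermutation<int64_t>({10, 20, 30}, perm);
  SmallVector<int64_t> inverse = invertPermutationVector(perm);
  return applyPermutation<int64_t>(permuted, inverse); // {10, 20, 30}
}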
Return a fused vector::ContractionOp which represents a pattern such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
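A hedged sketch of the OpRewritePattern idiom these declarations describe; the pattern below is illustrative only and is not the canonicalization implemented in this file:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;

// Hypothetical pattern: collapse broadcast-of-broadcast into one broadcast.
struct FoldDoubleBroadcast : OpRewritePattern<vector::BroadcastOp> {
  using Base::Base; // inherit constructors via the Base alias noted above

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto inner = op.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!inner)
      return rewriter.notifyMatchFailure(op, "source is not a broadcast");
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op.getResultVectorType(), inner.getSource());
    return success();
  }
};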
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)