42#include "llvm/ADT/ArrayRef.h"
43#include "llvm/ADT/STLExtras.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/Casting.h"
54#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
56#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
77 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
79 for (
bool b : denseElts.getValues<
bool>())
82 else if (!
b && val <= 0)
96 auto shape = m.getType().getShape();
99 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks,
shape)) {
100 if (maskIdx < dimSize)
113 auto maskOperands = m.getOperands();
114 for (
Value operand : maskOperands) {
115 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
117 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
130 vector::YieldOp::create(builder, loc);
// Fragment of a CombiningKind-vs-element-type compatibility check.
// NOTE(review): this chunk is garbled — original lines 139 and 147 are
// missing, so the result returned for the integer kinds (ADD/MUL through
// XOR fall-through group) is not visible here; confirm against upstream
// before relying on it. The float kinds visibly return isa<FloatType>.
136 switch (combiningKind) {
137 case CombiningKind::ADD:
138 case CombiningKind::MUL:
140 case CombiningKind::MINUI:
141 case CombiningKind::MINSI:
142 case CombiningKind::MAXUI:
143 case CombiningKind::MAXSI:
144 case CombiningKind::AND:
145 case CombiningKind::OR:
146 case CombiningKind::XOR:
// Floating-point-only combining kinds below; their return is visible.
148 case CombiningKind::MINNUMF:
149 case CombiningKind::MAXNUMF:
150 case CombiningKind::MINIMUMF:
151 case CombiningKind::MAXIMUMF:
152 return llvm::isa<FloatType>(elementType);
182 VectorType vectorType) {
183 unsigned elementVectorRank = 0;
184 VectorType elementVectorType =
185 llvm::dyn_cast<VectorType>(shapedType.getElementType());
186 if (elementVectorType)
187 elementVectorRank += elementVectorType.getRank();
188 return vectorType.getRank() - elementVectorRank;
192 VectorType vectorType) {
195 if (shapedType.getRank() == 0 &&
201 shapedType.getRank(),
203 shapedType.getContext());
210 vector::TransferReadOp read) {
211 auto readMask = read.getMask();
212 auto writeMask = write.getMask();
218 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
219 if (!couldBeSameSplat)
236 vector::TransferReadOp read) {
237 return !defWrite.hasOutOfBoundsDim() &&
238 defWrite.getIndices() == read.getIndices() &&
239 defWrite.getVectorType() == read.getVectorType() &&
240 defWrite.getPermutationMap() == read.getPermutationMap() &&
241 ((!defWrite.getMask() && !read.getMask()) ||
246 vector::TransferWriteOp priorWrite) {
247 return priorWrite.getIndices() == write.getIndices() &&
248 priorWrite.getMask() == write.getMask() &&
249 priorWrite.getVectorType() == write.getVectorType() &&
250 priorWrite.getPermutationMap() == write.getPermutationMap();
254 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
255 bool testDynamicValueUsingBounds) {
257 if (transferA.getVectorType() != transferB.getVectorType())
259 unsigned rankOffset = transferA.getLeadingShapedRank();
260 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
261 Value indexA = transferA.getIndices()[i];
262 Value indexB = transferB.getIndices()[i];
266 if (i < rankOffset) {
269 if (cstIndexA.has_value() && cstIndexB.has_value()) {
270 if (*cstIndexA != *cstIndexB)
274 if (testDynamicValueUsingBounds) {
277 FailureOr<uint64_t> delta =
279 if (succeeded(delta) && *delta != 0)
282 FailureOr<bool> testEqual =
284 if (succeeded(testEqual) && !testEqual.value())
290 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
291 if (cstIndexA.has_value() && cstIndexB.has_value()) {
292 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
293 if (distance >= vectorDim)
297 if (testDynamicValueUsingBounds) {
300 FailureOr<int64_t> delta =
302 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
305 FailureOr<int64_t> computeDelta =
307 if (succeeded(computeDelta)) {
308 if (std::abs(computeDelta.value()) >= vectorDim)
318 VectorTransferOpInterface transferB,
319 bool testDynamicValueUsingBounds) {
320 if (transferA.getBase() != transferB.getBase())
323 testDynamicValueUsingBounds);
333 for (
auto [posInDim, dimSize, offsetInDim] :
334 llvm::reverse(llvm::zip_equal(position,
shape, offsets))) {
336 if (posInDim < dimSize + offsetInDim)
340 posInDim = offsetInDim;
350 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
352 assert(constOp &&
"Unexpected non-constant index");
353 return constOp.value();
363 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
364 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
365 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
375 llvm::transform(foldResults, std::back_inserter(values),
377 if (
auto attr = dyn_cast<Attribute>(foldResult))
379 builder, loc, cast<IntegerAttr>(attr).getInt())
382 return cast<Value>(foldResult);
395 if (
lhs.getDefiningOp<vector::VectorScaleOp>())
397 if (
rhs.getDefiningOp<vector::VectorScaleOp>())
407 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
408 if (
auto intType = dyn_cast<IntegerType>(expectedType)) {
409 if (intAttr.getType() != expectedType)
410 return IntegerAttr::get(expectedType, intAttr.getInt());
416 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
417 auto intType = dyn_cast<IntegerType>(expectedType);
421 APFloat floatVal = floatAttr.getValue();
422 APInt intVal = floatVal.bitcastToAPInt();
423 return IntegerAttr::get(expectedType, intVal);
462struct VectorInlinerInterface :
public DialectInlinerInterface {
463 using DialectInlinerInterface::DialectInlinerInterface;
472void VectorDialect::initialize() {
474#define GET_ATTRDEF_LIST
475#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
480#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
483 addInterfaces<VectorInlinerInterface>();
485 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
486 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
488 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
490 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
491 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
492 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
503 return arith::ConstantOp::materialize(builder, value, type, loc);
519void vector::MultiDimReductionOp::build(
OpBuilder &builder,
522 CombiningKind kind) {
524 for (
const auto &en : llvm::enumerate(reductionMask))
526 reductionDims.push_back(en.index());
527 build(builder,
result, kind, source,
acc, reductionDims);
530OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
532 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
// Unrolling shape for vector.multi_reduction: the full source vector shape,
// copied into a SmallVector<int64_t, 4>.
// NOTE(review): the closing brace (original line 540) is not visible in this
// chunk; original line numbers are fused into the text by extraction.
537std::optional<SmallVector<int64_t, 4>>
538MultiDimReductionOp::getShapeForUnroll() {
539 return llvm::to_vector<4>(getSourceVectorType().
getShape());
542LogicalResult MultiDimReductionOp::verify() {
545 Type inferredReturnType;
546 auto sourceScalableDims = getSourceVectorType().getScalableDims();
547 for (
auto [dimIdx, dimSize] :
548 llvm::enumerate(getSourceVectorType().
getShape()))
549 if (!llvm::any_of(getReductionDims(),
550 [dimIdx = dimIdx](
int64_t reductionDimIdx) {
551 return reductionDimIdx ==
static_cast<int64_t>(dimIdx);
553 targetShape.push_back(dimSize);
554 scalableDims.push_back(sourceScalableDims[dimIdx]);
557 if (targetShape.empty())
558 inferredReturnType = getSourceVectorType().getElementType();
560 inferredReturnType = VectorType::get(
561 targetShape, getSourceVectorType().
getElementType(), scalableDims);
562 if (
getType() != inferredReturnType)
564 <<
" is incompatible with source type "
565 << getSourceVectorType();
// Mask type expected when this op is masked: an i1 vector whose shape and
// scalable-dims flags mirror the source vector type.
// NOTE(review): the closing brace (original line 576) is not visible in this
// chunk; original line numbers are fused into the text by extraction.
571Type MultiDimReductionOp::getExpectedMaskType() {
572 auto vecType = getSourceVectorType();
573 return VectorType::get(vecType.getShape(),
574 IntegerType::get(vecType.getContext(), 1),
575 vecType.getScalableDims());
584struct ElideUnitDimsInMultiDimReduction
588 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
589 PatternRewriter &rewriter)
const override {
590 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
591 for (
const auto &dim :
enumerate(shape)) {
592 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
597 OpBuilder::InsertionGuard guard(rewriter);
600 if (reductionOp.isMasked()) {
602 rootOp = reductionOp.getMaskingOp();
603 mask = reductionOp.getMaskingOp().getMask();
605 rootOp = reductionOp;
608 Location loc = reductionOp.getLoc();
609 Value acc = reductionOp.getAcc();
611 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
613 VectorType newMaskType =
614 VectorType::get(dstVecType.getShape(), rewriter.
getI1Type(),
615 dstVecType.getScalableDims());
616 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
618 cast = vector::ShapeCastOp::create(
619 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
624 mask = vector::ExtractOp::create(rewriter, loc, mask);
625 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
630 cast,
nullptr, mask);
637void MultiDimReductionOp::getCanonicalizationPatterns(
639 results.
add<ElideUnitDimsInMultiDimReduction>(context);
648 arith::FastMathFlags fastMathFlags) {
654 arith::FastMathFlags fastMathFlags) {
656 llvm::cast<VectorType>(
vector.getType()).getElementType(), kind,
vector,
660LogicalResult ReductionOp::verify() {
662 int64_t rank = getSourceVectorType().getRank();
664 return emitOpError(
"unsupported reduction rank: ") << rank;
667 Type eltType = getDest().getType();
670 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
// Mask type expected when vector.reduction is masked: an i1 vector with the
// same shape and scalable-dims flags as the source vector type (parallel to
// MultiDimReductionOp::getExpectedMaskType above).
// NOTE(review): the closing brace (original line 684) is not visible in this
// chunk; original line numbers are fused into the text by extraction.
679Type ReductionOp::getExpectedMaskType() {
680 auto vecType = getSourceVectorType();
681 return VectorType::get(vecType.getShape(),
682 IntegerType::get(vecType.getContext(), 1),
683 vecType.getScalableDims());
// Body of a switch mapping each arith::AtomicRMWKind to a vector.reduction
// op built with the matching vector::CombiningKind (addf/addi -> ADD,
// mulf/muli -> MUL, and so on for min/max/and/or/xor variants).
// NOTE(review): the enclosing function signature, the `switch` header, and
// any default case are outside this chunk; line numbers are fused into the
// text by extraction and each statement is wrapped mid-call.
690 case arith::AtomicRMWKind::addf:
691 case arith::AtomicRMWKind::addi:
692 return vector::ReductionOp::create(builder,
vector.getLoc(),
693 CombiningKind::ADD,
vector);
694 case arith::AtomicRMWKind::mulf:
695 case arith::AtomicRMWKind::muli:
696 return vector::ReductionOp::create(builder,
vector.getLoc(),
697 CombiningKind::MUL,
vector);
698 case arith::AtomicRMWKind::minimumf:
699 return vector::ReductionOp::create(builder,
vector.getLoc(),
700 CombiningKind::MINIMUMF,
vector);
701 case arith::AtomicRMWKind::mins:
702 return vector::ReductionOp::create(builder,
vector.getLoc(),
703 CombiningKind::MINSI,
vector);
704 case arith::AtomicRMWKind::minu:
705 return vector::ReductionOp::create(builder,
vector.getLoc(),
706 CombiningKind::MINUI,
vector);
707 case arith::AtomicRMWKind::maximumf:
708 return vector::ReductionOp::create(builder,
vector.getLoc(),
709 CombiningKind::MAXIMUMF,
vector);
710 case arith::AtomicRMWKind::maxs:
711 return vector::ReductionOp::create(builder,
vector.getLoc(),
712 CombiningKind::MAXSI,
vector);
713 case arith::AtomicRMWKind::maxu:
714 return vector::ReductionOp::create(builder,
vector.getLoc(),
715 CombiningKind::MAXUI,
vector);
716 case arith::AtomicRMWKind::andi:
717 return vector::ReductionOp::create(builder,
vector.getLoc(),
718 CombiningKind::AND,
vector);
719 case arith::AtomicRMWKind::ori:
720 return vector::ReductionOp::create(builder,
vector.getLoc(),
721 CombiningKind::OR,
vector);
722 case arith::AtomicRMWKind::minnumf:
723 return vector::ReductionOp::create(builder,
vector.getLoc(),
724 CombiningKind::MINNUMF,
vector);
725 case arith::AtomicRMWKind::maxnumf:
726 return vector::ReductionOp::create(builder,
vector.getLoc(),
727 CombiningKind::MAXNUMF,
vector);
728 case arith::AtomicRMWKind::xori:
729 return vector::ReductionOp::create(builder,
vector.getLoc(),
730 CombiningKind::XOR,
vector);
// Unrolling shape for vector.reduction: the full source vector shape
// (parallel to MultiDimReductionOp::getShapeForUnroll).
// NOTE(review): the closing brace (original line 740) is not visible in this
// chunk; original line numbers are fused into the text by extraction.
738std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
739 return llvm::to_vector<4>(getSourceVectorType().
getShape());
746 LogicalResult matchAndRewrite(ReductionOp reductionOp,
751 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
754 if (maskableOp.isMasked()) {
756 rootOp = maskableOp.getMaskingOp();
757 mask = maskableOp.getMaskingOp().getMask();
759 rootOp = reductionOp;
762 auto vectorType = reductionOp.getSourceVectorType();
763 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
766 Location loc = reductionOp.getLoc();
768 mask = ExtractOp::create(rewriter, loc, mask);
769 Value
result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
771 if (Value acc = reductionOp.getAcc())
774 reductionOp.getFastmathAttr(), mask);
784 results.
add<ElideSingleElementReduction>(context);
798 getIndexingMapsAttrName(
result.name),
802 getIteratorTypesAttrName(
result.name),
805 return IteratorTypeAttr::get(builder.getContext(), t);
814 ContractionOp::getDefaultKind());
820 ArrayAttr iteratorTypes, CombiningKind kind) {
823 result.addAttribute(getIndexingMapsAttrName(
result.name), indexingMaps);
824 result.addAttribute(getIteratorTypesAttrName(
result.name), iteratorTypes);
826 CombiningKindAttr::get(builder.
getContext(), kind));
837 DictionaryAttr dictAttr;
851 result.attributes.append(dictAttr.getValue().begin(),
852 dictAttr.getValue().end());
858 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
859 result.attributes.get(getIteratorTypesAttrName(
result.name)));
860 if (!iteratorTypes) {
862 <<
"expected " << getIteratorTypesAttrName(
result.name)
863 <<
" array attribute";
868 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
869 auto maybeIteratorType = symbolizeIteratorType(s);
870 if (!maybeIteratorType.has_value())
871 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
873 iteratorTypeAttrs.push_back(
874 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
876 result.attributes.set(getIteratorTypesAttrName(
result.name),
879 if (!
result.attributes.get(getKindAttrName(
result.name))) {
881 getKindAttrName(
result.name),
882 CombiningKindAttr::get(
result.getContext(),
883 ContractionOp::getDefaultKind()));
885 if (masksInfo.empty())
887 if (masksInfo.size() != 2)
889 "expected zero or exactly 2 vector mask operands");
890 auto lhsType = llvm::cast<VectorType>(types[0]);
891 auto rhsType = llvm::cast<VectorType>(types[1]);
893 std::array<VectorType, 2> maskTypes = {
903 auto attrNames = getTraitAttrNames();
905 traitAttrsSet.insert_range(attrNames);
907 for (
auto attr : (*this)->getAttrs()) {
908 if (attr.getName() == getIteratorTypesAttrName()) {
910 llvm::cast<ArrayAttr>(attr.getValue())
911 .getAsValueRange<IteratorTypeAttr, IteratorType>();
917 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) ->
Attribute {
918 return StringAttr::get(
getContext(), stringifyIteratorType(t));
921 attrs.emplace_back(getIteratorTypesAttrName(),
922 ArrayAttr::get(
getContext(), iteratorTypeNames));
923 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
924 attrs.push_back(attr);
927 auto dictAttr = DictionaryAttr::get(
getContext(), attrs);
928 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
929 p << getRhs() <<
", " << getAcc();
932 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
937 const std::vector<std::pair<int64_t, int64_t>> &map) {
938 for (
auto &dimPair : map) {
939 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
940 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
941 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
948 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
950 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
951 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
954 for (
auto &dimPair : contractingDimMap) {
955 lhsContractingDimSet.insert(dimPair.first);
956 rhsContractingDimSet.insert(dimPair.second);
959 llvm::make_second_range(batchDimMap));
963 for (
int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
964 if (lhsContractingDimSet.count(i) > 0)
966 expectedResultDims.push_back(lhsType.getDimSize(i));
970 for (
int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
971 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
973 expectedResultDims.push_back(rhsType.getDimSize(i));
977 if (expectedResultDims.empty()) {
979 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
980 return op.emitOpError(
"invalid accumulator/result vector shape");
983 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
984 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
985 if (!resVectorType || !accVectorType)
986 return op.emitOpError(
"invalid accumulator/result vector shape");
992 AffineMap lhsMap = op.getIndexingMapsArray()[0];
993 AffineMap rhsMap = op.getIndexingMapsArray()[1];
995 return op.emitOpError(
996 "expected all dimensions to be either a LHS or a RHS dimension");
999 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1000 VectorType v = pair.first;
1001 auto map = pair.second;
1002 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1003 unsigned pos = map.getDimPosition(idx);
1008 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
1009 return op.emitOpError(
"expected all dimensions to get an extent as "
1010 "either a LHS or a RHS dimension");
1012 AffineMap resMap = op.getIndexingMapsArray()[2];
1017 assert(llvm::all_of(expectedMap.
getResults(),
1018 llvm::IsaPred<AffineConstantExpr>) &&
1019 "expected constant extent along all dimensions.");
1021 auto expectedShape =
1023 return cast<AffineConstantExpr>(e).getValue();
1026 VectorType::get(expectedShape, resVectorType.getElementType(),
1027 resVectorType.getScalableDims());
1028 if (resVectorType != expected || accVectorType != expected)
1029 return op.emitOpError(
1030 "invalid accumulator/result vector shape, expected: ")
1036LogicalResult ContractionOp::verify() {
1037 VectorType lhsType = getLhsType();
1038 VectorType rhsType = getRhsType();
1039 Type accType = getAccType();
1040 Type resType = getResultType();
1042 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1043 if (!lhsType.getElementType().isSignlessInteger())
1044 return emitOpError(
"only supports signless integer types");
1048 if (getIndexingMapsArray().size() != 3)
1049 return emitOpError(
"expected an indexing map for each vector operand");
1054 unsigned numIterators = getIteratorTypes().getValue().size();
1055 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1056 auto index = it.index();
1057 auto map = it.value();
1058 if (map.getNumSymbols() != 0)
1060 <<
index <<
" to have no symbols";
1061 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(
index).
getType());
1062 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1065 if (map.getNumDims() != numIterators)
1067 <<
index <<
" to have " << numIterators <<
" number of inputs";
1068 if (map.getNumResults() != rank)
1070 <<
index <<
" to have " << rank <<
" number of outputs";
1071 if (!map.isProjectedPermutation())
1073 <<
index <<
" to be a projected permutation of its inputs";
1076 auto contractingDimMap = getContractingDimMap();
1077 auto batchDimMap = getBatchDimMap();
1080 if (contractingDimMap.empty())
1081 return emitOpError(
"expected at least one contracting dimension pair");
1084 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1085 return emitOpError(
"invalid contracting dimension map");
1089 return emitOpError(
"invalid batch dimension map");
1093 contractingDimMap, batchDimMap)))
1096 if (!getKindAttr()) {
1097 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
1098 "'vector.kind<add>')");
1102 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1103 auto elementType = vectorType ? vectorType.getElementType() : resType;
1105 return emitOpError(
"unsupported contraction type");
1108 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1115Type ContractionOp::getExpectedMaskType() {
1116 auto indexingMaps = this->getIndexingMapsArray();
1119 VectorType lhsType = this->getLhsType();
1120 VectorType rhsType = this->getRhsType();
1122 unsigned numVecDims = lhsIdxMap.
getNumDims();
1128 for (
auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1131 lhsType.getScalableDims()[dimIdx];
1133 for (
auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1136 rhsType.getScalableDims()[dimIdx];
1139 assert(ShapedType::isStaticShape(maskShape) &&
1140 "Mask shape couldn't be computed");
1142 return VectorType::get(maskShape,
1143 IntegerType::get(lhsType.getContext(), 1),
1144 maskShapeScalableDims);
1149 getIteratorTypesAttrName(), getKindAttrName()};
1159static std::vector<std::pair<int64_t, int64_t>>
1161 IteratorType targetIteratorType,
MLIRContext *context) {
1162 std::vector<std::pair<int64_t, int64_t>> dimMap;
1163 for (
const auto &it : llvm::enumerate(iteratorTypes)) {
1164 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1165 if (iteratorType != targetIteratorType)
1171 if (lhsDim >= 0 && rhsDim >= 0)
1172 dimMap.emplace_back(lhsDim, rhsDim);
1177void ContractionOp::getIterationBounds(
1179 auto lhsShape = getLhsType().getShape();
1180 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1182 for (
const auto &it : llvm::enumerate(getIteratorTypes())) {
1185 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1186 if (iteratorType == IteratorType::reduction) {
1189 assert(lhsDimIndex >= 0);
1190 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1195 assert(resDimIndex >= 0);
1196 assert(resVectorType !=
nullptr);
1197 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1201void ContractionOp::getIterationIndexMap(
1203 unsigned numMaps = getIndexingMapsArray().size();
1204 iterationIndexMap.resize(numMaps);
1205 for (
const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1206 auto index = it.index();
1207 auto map = it.value();
1208 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1209 auto dim = cast<AffineDimExpr>(map.getResult(i));
1210 iterationIndexMap[
index][dim.getPosition()] = i;
1215std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1217 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1221std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1223 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1227std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1229 getIterationBounds(
shape);
1251template <
typename AddOpType>
1257 auto canonicalize = [&](
Value maybeContraction,
1258 Value otherOperand) -> vector::ContractionOp {
1259 vector::ContractionOp contractionOp =
1260 dyn_cast_or_null<vector::ContractionOp>(
1263 return vector::ContractionOp();
1264 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1265 contractionOp.getAcc().getDefiningOp())) {
1266 if (maybeZero.getValue() ==
1267 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1269 bvm.
map(contractionOp.getAcc(), otherOperand);
1270 auto newContraction =
1271 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1272 rewriter.
replaceOp(addOp, newContraction.getResult());
1273 return newContraction;
1276 return vector::ContractionOp();
1279 Value a = addOp->getOperand(0),
b = addOp->getOperand(1);
1280 vector::ContractionOp
contract = canonicalize(a,
b);
1305 setResultRanges(getResult(), argRanges.front());
1310 auto vectorTy = cast<VectorType>(source.
getType());
1335 build(builder,
result, source, dynamicPos,
1340ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1341 ExtractOp::Adaptor adaptor,
1343 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1344 if (
static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1345 vectorType.getRank()) {
1346 inferredReturnTypes.push_back(vectorType.getElementType());
1348 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1349 vectorType.getRank());
1350 inferredReturnTypes.push_back(VectorType::get(
1351 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1352 vectorType.getScalableDims().drop_front(n)));
1360 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1361 return vectorType && vectorType.getShape().equals({1}) &&
1362 vectorType.getElementType() == r.front();
1364 if (l.size() == 1 && r.size() == 1 &&
1365 (isCompatible(l, r) || isCompatible(r, l)))
1370LogicalResult vector::ExtractOp::verify() {
1371 if (
auto resTy = dyn_cast<VectorType>(getResult().
getType()))
1372 if (resTy.getRank() == 0)
1374 "expected a scalar instead of a 0-d vector as the result type");
1377 auto dynamicMarkersCount =
1378 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1379 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1381 "mismatch between dynamic and static positions (kDynamic marker but no "
1382 "corresponding dynamic position) -- this can only happen due to an "
1383 "incorrect fold/rewrite");
1384 auto position = getMixedPosition();
1385 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1387 "expected position attribute of rank no greater than vector rank");
1388 for (
auto [idx, pos] : llvm::enumerate(position)) {
1389 if (
auto attr = dyn_cast<Attribute>(pos)) {
1390 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1392 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1393 return emitOpError(
"expected position attribute #")
1395 <<
" to be a non-negative integer smaller than the "
1396 "corresponding vector dimension or poison (-1)";
1403template <
typename IntType>
1405 return llvm::map_to_vector<4>(
1406 arrayAttr.getAsRange<IntegerAttr>(),
1407 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1413 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1417 if (extractOp.hasDynamicPosition())
1421 ExtractOp currentOp = extractOp;
1423 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1424 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1427 if (currentOp.hasDynamicPosition())
1430 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1432 extractOp.setOperand(0, currentOp.getSource());
1435 std::reverse(globalPosition.begin(), globalPosition.end());
1436 extractOp.setStaticPosition(globalPosition);
1448class ExtractFromInsertTransposeChainState {
1450 ExtractFromInsertTransposeChainState(ExtractOp e);
1459 template <
typename ContainerA,
typename ContainerB>
1460 bool isContainedWithin(
const ContainerA &a,
const ContainerB &
b) {
1461 return a.size() <=
b.size() &&
1462 std::equal(a.begin(), a.begin() + a.size(),
b.begin());
1469 template <
typename ContainerA,
typename ContainerB>
1470 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &
b) {
1471 for (
auto [elemA, elemB] : llvm::zip(a,
b)) {
1472 if (elemA < 0 || elemB < 0)
1483 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1487 void updateStateForNextIteration(Value v) {
1494 LogicalResult handleTransposeOp();
1497 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1512 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1517 Value tryToFoldExtractOpInPlace(Value source);
1519 ExtractOp extractOp;
1521 int64_t extractedRank;
1523 InsertOp nextInsertOp;
1524 TransposeOp nextTransposeOp;
1534 SmallVector<int64_t> sentinels;
1535 SmallVector<int64_t> extractPosition;
1539ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1541 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1542 extractedRank(extractOp.getNumIndices()) {
1543 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1544 sentinels.reserve(vectorRank - extractedRank);
1545 for (
int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1546 sentinels.push_back(-(i + 1));
1548 extractOp.getStaticPosition().end());
1554LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1556 if (extractOp.hasDynamicPosition())
1559 if (!nextTransposeOp)
1562 nextTransposeOp.getPermutation(), extractOp.getContext()));
1569ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1572 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1575 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1576 if (insertedPos != llvm::ArrayRef(
extractPosition).take_front(extractedRank))
1579 res = nextInsertOp.getValueToStore();
1588ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1590 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1593 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1603 res = nextInsertOp.getValueToStore();
1611Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1614 if (extractOp.hasDynamicPosition())
1618 bool nothingToFold = (source == extractOp.getSource());
1619 if (nothingToFold || !canFold())
1623 OpBuilder
b(extractOp.getContext());
1624 extractOp.setStaticPosition(
1626 extractOp.getSourceMutable().assign(source);
1627 return extractOp.getResult();
1631Value ExtractFromInsertTransposeChainState::fold() {
1633 if (extractOp.hasDynamicPosition())
1636 Value valueToExtractFrom = extractOp.getSource();
1637 updateStateForNextIteration(valueToExtractFrom);
1638 while (nextInsertOp || nextTransposeOp) {
1641 if (succeeded(handleTransposeOp())) {
1642 valueToExtractFrom = nextTransposeOp.getVector();
1643 updateStateForNextIteration(valueToExtractFrom);
1649 if (succeeded(handleInsertOpWithMatchingPos(
result)))
1654 if (succeeded(handleInsertOpWithPrefixPos(
result)))
1655 return tryToFoldExtractOpInPlace(
result);
1659 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1665 valueToExtractFrom = nextInsertOp.getDest();
1666 updateStateForNextIteration(valueToExtractFrom);
1669 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1674 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1675 auto vecType = dyn_cast<VectorType>(type);
1676 return vecType && vecType.getRank() == 0;
1686 if (isa<BroadcastOp>(op))
1689 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1697 VectorType srcType = shapeCast.getSourceVectorType();
1699 uint64_t srcRank = srcType.getRank();
1701 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1727 Operation *defOp = extractOp.getSource().getDefiningOp();
1734 if (extractOp.getType() == input.
getType())
1740 auto inputType = llvm::dyn_cast<VectorType>(input.
getType());
1741 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1742 unsigned inputRank = inputType ? inputType.getRank() : 0;
1743 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1744 unsigned extractRank = extractType ? extractType.getRank() : 0;
1747 if (extractRank > inputRank)
1751 assert(inputType &&
"input must be a vector type because of previous checks");
1760 extractType.getShape() != inputShape.take_back(extractRank))
1765 unsigned deltaOverall = inputRank - extractRank;
1766 unsigned deltaBroadcast = broadcastRank - inputRank;
1770 for (
auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1771 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1774 extractOp->setOperands(
1775 llvm::to_vector(llvm::concat<Value>(
ValueRange(input), dynPos)));
1776 extractOp.setStaticPosition(staticPos);
1777 return extractOp.getResult();
1793 if (extractOp.hasDynamicPosition())
1796 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1801 if (shuffleOp.getResultVectorType().getRank() != 1)
1804 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1805 auto shuffleMask = shuffleOp.getMask();
1806 int64_t extractIdx = extractOp.getStaticPosition()[0];
1807 int64_t shuffleIdx = shuffleMask[extractIdx];
1810 if (shuffleIdx < inputVecSize) {
1811 extractOp.setOperand(0, shuffleOp.getV1());
1812 extractOp.setStaticPosition({shuffleIdx});
1814 extractOp.setOperand(0, shuffleOp.getV2());
1815 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1818 return extractOp.getResult();
1824 if (extractOp.hasDynamicPosition())
1827 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1832 auto getDimReverse = [](VectorType type,
int64_t n) {
1833 return type.getShape().take_back(n + 1).front();
1836 llvm::isa<VectorType>(extractOp.getType())
1837 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1839 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1841 if (destinationRank > 0) {
1842 auto destinationType =
1843 llvm::cast<VectorType>(extractOp.getResult().getType());
1844 for (
int64_t i = 0; i < destinationRank; i++) {
1848 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1849 getDimReverse(destinationType, i))
1856 std::reverse(extractedPos.begin(), extractedPos.end());
1859 for (
int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1860 strides.push_back(stride);
1862 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1870 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1872 for (
int64_t i = 0; i < numDimension; i++) {
1873 newStrides.push_back(stride);
1875 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1877 std::reverse(newStrides.begin(), newStrides.end());
1881 extractOp.setStaticPosition(newPosition);
1882 extractOp.setOperand(0, shapeCastOp.getSource());
1883 return extractOp.getResult();
1889 if (extractOp.hasDynamicPosition())
1892 auto extractStridedSliceOp =
1893 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1894 if (!extractStridedSliceOp)
1903 if (extractStridedSliceOp.hasNonUnitStrides())
1909 while (!sliceOffsets.empty()) {
1910 size_t lastOffset = sliceOffsets.size() - 1;
1911 if (sliceOffsets.back() != 0 ||
1912 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1913 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1915 sliceOffsets.pop_back();
1917 unsigned destinationRank = 0;
1918 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1919 destinationRank = vecType.getRank();
1922 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1923 sliceOffsets.size())
1927 assert(extractedPos.size() >= sliceOffsets.size());
1928 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1929 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1930 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1934 extractOp.setStaticPosition(extractedPos);
1935 return extractOp.getResult();
1941 if (extractOp.hasDynamicPosition())
1945 llvm::isa<VectorType>(extractOp.getType())
1946 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1948 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1958 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1959 insertOp.getSourceVectorType().getRank();
1960 if (destinationRank > insertOp.getSourceVectorType().getRank())
1965 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1966 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1969 bool disjoint =
false;
1971 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1972 int64_t start = insertOffsets[dim];
1974 (dim < insertRankDiff)
1976 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1978 int64_t offset = extractOffsets[dim];
1980 if (start <= offset && offset < end) {
1981 if (dim >= insertRankDiff)
1982 offsetDiffs.push_back(offset - start);
1993 insertOp.getSourceVectorType().getRank() - destinationRank;
1994 for (
int64_t i = 0; i < destinationRank; i++) {
1995 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1996 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
2000 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2003 extractOp.setStaticPosition(offsetDiffs);
2004 return extractOp.getResult();
2008 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2021 if (extractOp.hasDynamicPosition())
2025 auto fromElementsOp = extractOp.getSource().
getDefiningOp<FromElementsOp>();
2026 if (!fromElementsOp)
2030 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2031 if (vecType.isScalable())
2035 int64_t rank = vecType.getRank();
2037 if (extractOp.getType() != vecType.getElementType())
2040 "unexpected number of indices");
2045 for (
int i = rank - 1; i >= 0; --i) {
2046 flatIndex +=
indices[i] * stride;
2047 stride *= vecType.getDimSize(i);
2049 return fromElementsOp.getElements()[flatIndex];
2054template <
typename OpType,
typename AdaptorType>
2057 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2058 OperandRange dynamicPosition = op.getDynamicPosition();
2061 if constexpr (std::is_same_v<OpType, ExtractOp>)
2062 vectorShape = op.getSourceVectorType().getShape();
2067 if (!dynamicPosition.size())
2074 bool opChange =
false;
2075 for (
unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2076 if (ShapedType::isStatic(staticPosition[i]))
2080 if (
auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2081 int64_t value = attr.getInt();
2085 staticPosition[i] = attr.getInt();
2090 operands.push_back(position);
2094 op.setStaticPosition(staticPosition);
2095 op.getOperation()->setOperands(operands);
2097 return op.getResult();
2107 if (!is_contained(staticPos, poisonVal))
2110 return ub::PoisonAttr::get(context);
2124 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2129 if (denseAttr.isSplat()) {
2131 if (
auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2136 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2137 if (vecTy.isScalable())
2140 if (extractOp.hasDynamicPosition()) {
2155 copy(extractOp.getStaticPosition(), completePositions.begin());
2158 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2161 if (
auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2163 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2166 newAttr = *denseValuesBegin;
2172OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2176 if (getNumIndices() == 0 && getSource().
getType() == getResult().
getType())
2183 SmallVector<Value> operands = {getSource()};
2187 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2193 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2208 return inplaceFolded;
2214class ExtractOpFromBroadcast final :
public OpRewritePattern<ExtractOp> {
2218 LogicalResult matchAndRewrite(ExtractOp extractOp,
2219 PatternRewriter &rewriter)
const override {
2222 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2228 BroadcastableToResult::Success)
2237class ExtractOpFromCreateMask final :
public OpRewritePattern<ExtractOp> {
2241 LogicalResult matchAndRewrite(ExtractOp extractOp,
2242 PatternRewriter &rewriter)
const override {
2244 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2248 VectorType extractedMaskType =
2249 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2251 if (!extractedMaskType)
2254 auto maskOperands = createMaskOp.getOperands();
2255 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2256 VectorType maskType = createMaskOp.getVectorType();
2258 bool containsUnknownDims =
false;
2261 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2263 int64_t pos = extractOpPos[dimIdx];
2264 Value operand = maskOperands[dimIdx];
2265 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2268 containsUnknownDims =
true;
2272 int64_t createMaskBound =
2273 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2275 if (pos != ShapedType::kDynamic) {
2278 allFalse |= pos >= createMaskBound;
2279 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2283 containsUnknownDims =
true;
2290 }
else if (!containsUnknownDims) {
2292 extractOp, extractedMaskType,
2293 maskOperands.drop_front(extractOpPos.size()));
2302class ExtractOpFromConstantMask final :
public OpRewritePattern<ExtractOp> {
2306 LogicalResult matchAndRewrite(ExtractOp extractOp,
2307 PatternRewriter &rewriter)
const override {
2308 auto constantMaskOp =
2309 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2310 if (!constantMaskOp)
2313 Type resultType = extractOp.getResult().getType();
2314 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2316 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2317 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2319 VectorType maskType = constantMaskOp.getVectorType();
2322 for (
size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2323 int64_t pos = extractOpPos[dimIdx];
2324 if (pos == ShapedType::kDynamic) {
2327 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2336 if (pos >= maskDimSizes[dimIdx]) {
2337 if (extractedMaskType) {
2349 if (extractedMaskType) {
2353 extractOp, extractedMaskType,
2354 maskDimSizes.drop_front(extractOpPos.size()));
2367LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2368 PatternRewriter &rewriter) {
2369 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2373 VectorType sourceType = castOp.getSourceVectorType();
2374 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2378 if (sourceType.getNumElements() != targetType.getNumElements())
2382 castOp.getSource());
2392LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2393 PatternRewriter &rewriter) {
2395 if (extractOp.hasDynamicPosition())
2399 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2404 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2405 if (!fromElementsOp)
2407 VectorType inputType = fromElementsOp.getType();
2410 if (resultType.isScalable() || inputType.isScalable())
2415 SmallVector<int64_t> firstElementPos =
2416 llvm::to_vector(extractOp.getStaticPosition());
2417 firstElementPos.append(resultType.getRank(), 0);
2420 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2421 flatIndex += firstElementPos[i] * stride;
2422 stride *= inputType.getDimSize(i);
2427 extractOp, resultType,
2428 fromElementsOp.getElements().slice(flatIndex,
2429 resultType.getNumElements()));
2441struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2443 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2444 PatternRewriter &rewriter)
const override {
2445 VectorType sourceType = extractOp.getSourceVectorType();
2446 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2450 if (sourceType.getNumElements() != outType.getNumElements())
2452 extractOp,
"extract to vector with fewer elements");
2456 if (llvm::any_of(extractOp.getMixedPosition(),
2457 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2459 "leaving for extract poison folder");
2462 extractOp.getSource());
2470void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2471 MLIRContext *context) {
2472 results.
add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2473 ExtractOpFromConstantMask, ExtractToShapeCast>(context);
2474 results.
add(foldExtractFromShapeCastToShapeCast);
2475 results.
add(foldExtractFromFromElements);
2480 for (
auto attr : arrayAttr)
2481 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2488std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2499 if (operands.empty())
2502 return llvm::all_of(operands, [&](
Value operand) {
2504 return currentDef == defOp;
2522 auto fromElementsOp =
2523 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2524 if (!fromElementsOp)
2527 llvm::append_range(results, fromElementsOp.getElements());
2544 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2548 if (isa<VectorType>(bcastOp.getSource().getType()))
2551 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2553 Value scalar = bcastOp.getSource();
2554 results.assign(resultVecType.getNumElements(), scalar);
2558LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2559 SmallVectorImpl<OpFoldResult> &results) {
2564 if (
auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2565 setOperand(shapeCast.getSource());
2573ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2574 ToElementsOp::Adaptor adaptor,
2575 SmallVectorImpl<Type> &inferredReturnTypes) {
2576 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2577 Type elType = vecType.getElementType();
2578 inferredReturnTypes.append(vecType.getNumElements(), elType);
2600 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2605 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2609 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2614 int64_t dstRank = dstShape.size();
2615 int64_t srcRank = srcShape.size();
2618 auto srcElems = vector::ToElementsOp::create(
2619 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2621 int64_t dstCount = llvm::product_of(dstShape);
2624 replacements.reserve(dstCount);
2649 for (
int64_t lin = 0; lin < dstCount; ++lin) {
2652 for (
int64_t k = 0; k < srcRank; ++k)
2653 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2656 replacements.push_back(srcElems.getResult(srcLin));
2659 rewriter.
replaceOp(toElementsOp, replacements);
2664void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2665 MLIRContext *context) {
2666 results.
add<ToElementsOfBroadcast>(context);
2686 OperandRange fromElemsOperands = fromElementsOp.getElements();
2687 if (fromElemsOperands.empty())
2690 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2698 Value toElementsInput = toElementsOp.getSource();
2699 if (fromElementsOp.getType() == toElementsInput.
getType() &&
2700 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2701 return toElementsInput;
2721 if (llvm::any_of(elements, [](
Attribute attr) {
2727 auto destVecType = fromElementsOp.getDest().getType();
2728 auto destEltType = destVecType.getElementType();
2729 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2734 auto convertedElements = llvm::map_to_vector(elements, [&](
Attribute attr) {
2741OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2758 if (!llvm::all_equal(fromElementsOp.getElements()))
2761 fromElementsOp, fromElementsOp.getType(),
2762 fromElementsOp.getElements().front());
2790 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2794 if (fromElements.getType().getNumElements() == 1)
2805 for (
auto [insertIndex, element] :
2806 llvm::enumerate(fromElements.getElements())) {
2809 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2812 "element not from vector.extract");
2817 if (insertIndex == 0) {
2818 source = extractOp.getSource();
2819 }
else if (extractOp.getSource() != source) {
2821 "element from different vector");
2825 int64_t rank = position.size();
2826 assert(rank == source.getType().getRank() &&
2827 "scalar extract must have full rank position");
2838 if (insertIndex == 0) {
2839 const int64_t numElms = fromElements.getType().getNumElements();
2842 while (
index > 0 && position[
index - 1] == 0 &&
2843 numSuffixElms < numElms) {
2844 numSuffixElms *= source.getType().getDimSize(
index - 1);
2847 if (numSuffixElms != numElms) {
2849 fromElements,
"elements do not form a suffix of source");
2851 expectedPosition = llvm::to_vector(position);
2852 combinedPosition = position.drop_back(rank -
index);
2856 else if (expectedPosition != position) {
2858 fromElements,
"elements not in ascending order (static order)");
2860 increment(expectedPosition, source.getType().getShape());
2863 auto extracted = rewriter.
createOrFold<vector::ExtractOp>(
2864 fromElements.getLoc(), source, combinedPosition);
2867 fromElements, fromElements.getType(), extracted);
2875 for (
int dim : llvm::reverse(llvm::seq<int>(0,
indices.size()))) {
2894void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2896 setResultRanges(getResult(), argRanges.front());
2899std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2900 return llvm::to_vector<4>(getResultVectorType().
getShape());
2905static llvm::SetVector<int64_t>
2908 int64_t rankDiff = dstShape.size() - srcShape.size();
2911 for (
auto [s1, s2] :
2912 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2914 assert(s1 == 1 &&
"expected \"dim-1\" broadcasting");
2922llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2924 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2927 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2943Value BroadcastOp::createOrFoldBroadcastOp(
2944 OpBuilder &
b, Value value, ArrayRef<int64_t> dstShape,
2945 const llvm::SetVector<int64_t> &broadcastedDims) {
2946 assert(!dstShape.empty() &&
"unexpected empty dst shape");
2949 SmallVector<int64_t> checkShape;
2950 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
2951 if (broadcastedDims.contains(i))
2953 checkShape.push_back(dstShape[i]);
2955 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2956 "ill-formed broadcastedDims contains values not confined to "
2959 Location loc = value.
getLoc();
2961 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
2962 VectorType dstVectorType = VectorType::get(dstShape, elementType);
2965 if (!srcVectorType) {
2966 assert(checkShape.empty() &&
2967 "ill-formed createOrFoldBroadcastOp arguments");
2968 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2971 assert(srcVectorType.getShape().equals(checkShape) &&
2972 "ill-formed createOrFoldBroadcastOp arguments");
2982 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2983 broadcastShape.reserve(dstShape.size());
2999 int64_t nextSrcShapeDim = broadcastedDims.size();
3000 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
3001 if (broadcastedDims.contains(i)) {
3006 broadcastShape.push_back(dstShape[i]);
3007 permutation[i] = broadcastShape.size() - 1;
3013 permutation[i] = nextSrcShapeDim++;
3017 llvm::append_range(broadcastShape, srcVectorType.getShape());
3022 "unexpected \"dim-1\" broadcast");
3024 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3026 vector::BroadcastableToResult::Success &&
3027 "must be broadcastable");
3028 Value res =
b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3031 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3032 if (permutation[i] != i)
3033 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3039 Type srcType, VectorType dstVectorType,
3040 std::pair<VectorDim, VectorDim> *mismatchingDims) {
3042 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3046 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3050 int64_t srcRank = srcVectorType.getRank();
3051 int64_t dstRank = dstVectorType.getRank();
3052 if (srcRank > dstRank)
3056 int64_t lead = dstRank - srcRank;
3057 for (
int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3060 bool foundMismatchingDims =
false;
3063 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3064 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3065 if (srcDim != 1 && srcDim != dstDim)
3066 foundMismatchingDims =
true;
3069 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3070 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3071 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3074 (srcDimScalableFlag != dstDimScalableFlag &&
3075 (srcDim != 1 || srcDimScalableFlag)))
3076 foundMismatchingDims =
true;
3078 if (foundMismatchingDims) {
3079 if (mismatchingDims !=
nullptr) {
3080 mismatchingDims->first.dim = srcDim;
3081 mismatchingDims->first.isScalable = srcDimScalableFlag;
3083 mismatchingDims->second.dim = dstDim;
3084 mismatchingDims->second.isScalable = dstDimScalableFlag;
3093LogicalResult BroadcastOp::verify() {
3094 std::pair<VectorDim, VectorDim> mismatchingDims;
3096 getSourceType(), getResultVectorType(), &mismatchingDims);
3100 return emitOpError(
"source rank higher than destination rank");
3103 << (mismatchingDims.first.isScalable ?
"[" :
"")
3104 << mismatchingDims.first.dim
3105 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
3106 << (mismatchingDims.second.isScalable ?
"[" :
"")
3107 << mismatchingDims.second.dim
3108 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
3111 return emitOpError(
"source type is not a vector");
3112 llvm_unreachable(
"unexpected vector.broadcast op error");
3119 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3123 VectorType srcType = srcShapeCast.getSourceVectorType();
3124 VectorType destType = broadcastOp.getResultVectorType();
3132 srcShapeCast.getResultVectorType().getShape();
3135 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3136 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3137 shapecastShape.take_back(numTrailingDims)))
3140 assert(all_of(srcShape.drop_back(numTrailingDims),
3141 [](
int64_t E) { return E == 1; }) &&
3142 all_of(shapecastShape.drop_back(numTrailingDims),
3143 [](
int64_t E) { return E == 1; }) &&
3144 "ill-formed shape_cast");
3146 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3150OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3151 if (getSourceType() == getResultVectorType())
3156 if (!adaptor.getSource())
3158 auto vectorType = getResultVectorType();
3159 if (
auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3160 if (vectorType.getElementType() != attr.getType())
3164 if (
auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3165 if (vectorType.getElementType() != attr.getType())
3169 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3179struct BroadcastFolder :
public OpRewritePattern<BroadcastOp> {
3182 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3183 PatternRewriter &rewriter)
const override {
3184 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3188 broadcastOp.getResultVectorType(),
3189 srcBroadcast.getSource());
3202struct BroadcastToShapeCast final
3203 :
public OpRewritePattern<vector::BroadcastOp> {
3205 LogicalResult matchAndRewrite(vector::BroadcastOp
broadcast,
3206 PatternRewriter &rewriter)
const override {
3208 auto sourceType = dyn_cast<VectorType>(
broadcast.getSourceType());
3211 broadcast,
"source is a scalar, shape_cast doesn't support scalar");
3215 if (sourceType.getNumElements() != outType.getNumElements()) {
3217 broadcast,
"broadcast to a greater number of elements");
3227void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3228 MLIRContext *context) {
3229 results.
add<BroadcastFolder, BroadcastToShapeCast>(context);
3236LogicalResult ShuffleOp::verify() {
3237 VectorType resultType = getResultVectorType();
3238 VectorType v1Type = getV1VectorType();
3239 VectorType v2Type = getV2VectorType();
3241 int64_t resRank = resultType.getRank();
3242 int64_t v1Rank = v1Type.getRank();
3243 int64_t v2Rank = v2Type.getRank();
3244 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3245 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3246 if (!wellFormed0DCase && !wellFormedNDCase)
3250 for (int64_t r = 1; r < v1Rank; ++r) {
3251 int64_t resDim = resultType.getDimSize(r);
3252 int64_t v1Dim = v1Type.getDimSize(r);
3253 int64_t v2Dim = v2Type.getDimSize(r);
3254 if (resDim != v1Dim || v1Dim != v2Dim)
3258 ArrayRef<int64_t> mask = getMask();
3259 int64_t maskLength = mask.size();
3260 if (maskLength <= 0)
3262 if (maskLength != resultType.getDimSize(0))
3265 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3266 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3267 for (
auto [idx, maskPos] : llvm::enumerate(mask)) {
3269 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
3275ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3276 ShuffleOp::Adaptor adaptor,
3277 SmallVectorImpl<Type> &inferredReturnTypes) {
3278 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3282 auto v1Rank = v1Type.getRank();
3285 SmallVector<int64_t, 4> shape;
3286 shape.reserve(v1Rank);
3287 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3290 llvm::append_range(shape, v1Type.getShape().drop_front());
3291 inferredReturnTypes.push_back(
3292 VectorType::get(shape, v1Type.getElementType()));
3296template <
typename T>
3299 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3300 return value == expected++;
3304OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3305 auto v1Type = getV1VectorType();
3306 auto v2Type = getV2VectorType();
3308 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3309 "Vector shuffle does not support scalable vectors");
3313 if (v1Type.getRank() == 0)
3317 auto mask = getMask();
3324 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3325 if (!v1Attr || !v2Attr)
3331 if (isV1Poison && isV2Poison)
3336 if (v1Type.getRank() != 1)
3342 SmallVector<Attribute> v1Elements, v2Elements;
3343 Attribute poisonElement;
3345 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3348 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3349 poisonElement = v2Elements[0];
3352 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3355 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3356 poisonElement = v1Elements[0];
3359 SmallVector<Attribute> results;
3360 int64_t v1Size = v1Type.getDimSize(0);
3361 for (int64_t maskIdx : mask) {
3362 Attribute indexedElm;
3364 if (maskIdx == ShuffleOp::kPoisonIndex) {
3365 indexedElm = poisonElement;
3367 if (maskIdx < v1Size)
3368 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3370 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3373 results.push_back(indexedElm);
3383struct Canonicalize0DShuffleOp :
public OpRewritePattern<ShuffleOp> {
3386 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3387 PatternRewriter &rewriter)
const override {
3388 VectorType v1VectorType = shuffleOp.getV1VectorType();
3389 ArrayRef<int64_t> mask = shuffleOp.getMask();
3390 if (v1VectorType.getRank() > 0)
3392 if (mask.size() != 1)
3394 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3412static Value getScalarSplatSource(Value value) {
3418 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3425 if (isa<VectorType>(
broadcast.getSourceType()))
3433class ShuffleSplat final :
public OpRewritePattern<ShuffleOp> {
3437 LogicalResult matchAndRewrite(ShuffleOp op,
3438 PatternRewriter &rewriter)
const override {
3439 Value splat = getScalarSplatSource(op.getV1());
3440 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3450class ShuffleInterleave :
public OpRewritePattern<ShuffleOp> {
3454 LogicalResult matchAndRewrite(ShuffleOp op,
3455 PatternRewriter &rewriter)
const override {
3456 VectorType resultType = op.getResultVectorType();
3457 if (resultType.isScalable())
3459 op,
"ShuffleOp can't represent a scalable interleave");
3461 if (resultType.getRank() != 1)
3463 op,
"ShuffleOp can't represent an n-D interleave");
3465 VectorType sourceType = op.getV1VectorType();
3466 if (sourceType != op.getV2VectorType() ||
3467 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3469 op,
"ShuffleOp types don't match an interleave");
3472 ArrayRef<int64_t> shuffleMask = op.getMask();
3473 int64_t resultVectorSize = resultType.getNumElements();
3474 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3475 int64_t maskValueA = shuffleMask[i * 2];
3476 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3477 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3479 "ShuffleOp mask not interleaving");
3489void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3490 MLIRContext *context) {
3491 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3499void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3501 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3504void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3505 Value source, Value dest) {
3506 auto vectorTy = cast<VectorType>(dest.
getType());
3507 build(builder,
result, source, dest,
3508 SmallVector<int64_t>(vectorTy.getRank(), 0));
3511void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3512 Value source, Value dest, int64_t position) {
3513 build(builder,
result, source, dest, ArrayRef<int64_t>{position});
3516void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3517 Value source, Value dest, OpFoldResult position) {
3518 build(builder,
result, source, dest, ArrayRef<OpFoldResult>{position});
3521void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3522 Value source, Value dest,
3523 ArrayRef<int64_t> position) {
3524 SmallVector<OpFoldResult> posVals;
3525 posVals.reserve(position.size());
3526 llvm::transform(position, std::back_inserter(posVals),
3528 build(builder,
result, source, dest, posVals);
3531void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3532 Value source, Value dest,
3533 ArrayRef<OpFoldResult> position) {
3534 SmallVector<int64_t> staticPos;
3535 SmallVector<Value> dynamicPos;
3537 build(builder,
result, source, dest, dynamicPos,
3541LogicalResult InsertOp::verify() {
3542 if (
auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3543 if (srcTy.getRank() == 0)
3545 "expected a scalar instead of a 0-d vector as the source operand");
3547 SmallVector<OpFoldResult> position = getMixedPosition();
3548 auto destVectorType = getDestVectorType();
3549 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
3551 "expected position attribute of rank no greater than dest vector rank");
3552 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3553 if (srcVectorType &&
3554 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3555 static_cast<unsigned>(destVectorType.getRank())))
3556 return emitOpError(
"expected position attribute rank + source rank to "
3557 "match dest vector rank");
3558 if (!srcVectorType &&
3559 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
3561 "expected position attribute rank to match the dest vector rank");
3562 for (
auto [idx, pos] : llvm::enumerate(position)) {
3563 if (
auto attr = dyn_cast<Attribute>(pos)) {
3564 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3566 destVectorType.getDimSize(idx))) {
3567 return emitOpError(
"expected position attribute #")
3569 <<
" to be a non-negative integer smaller than the "
3571 "dest vector dimension";
3584 assert(positions.size() <= completePositions.size() &&
3585 "positions size must be less than or equal to destTy rank");
3586 copy(positions, completePositions.begin());
// Rewrite pattern on vector.insert: when the inserted source is itself a
// vector with exactly as many elements as the destination vector, the
// insert covers the whole destination, and the op is replaced using only
// the destination type and the value to store.
// NOTE(review): extraction dropped lines here (the variable `srcVecType`
// is assigned on a dropped line, and the replacement-op name is not
// visible) — presumably a broadcast/shape_cast replacement; confirm
// against upstream before relying on this description.
3594class InsertToBroadcast final :
public OpRewritePattern<InsertOp> {
3598 LogicalResult matchAndRewrite(InsertOp insertOp,
3599 PatternRewriter &rewriter)
const override {
// Bail out unless source is a vector whose element count matches dest.
3601 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3602 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3603 srcVecType.getNumElements())
3606 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
// Rewrite pattern on vector.insert: if both the value being inserted and
// the destination are splats of the same scalar, the insert is a no-op
// modulo shape and can be folded away (replacement is on a dropped line).
3612class InsertSplatToSplat final :
public OpRewritePattern<InsertOp> {
3616 LogicalResult matchAndRewrite(InsertOp op,
3617 PatternRewriter &rewriter)
const override {
// `splat` is the scalar that the stored value broadcasts; the pattern
// only applies when the destination broadcasts the same scalar.
3619 Value splat = getScalarSplatSource(op.getValueToStore());
3620 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3648class InsertChainFullyInitialized final :
public OpRewritePattern<InsertOp> {
3651 LogicalResult matchAndRewrite(InsertOp op,
3652 PatternRewriter &rewriter)
const override {
3654 VectorType destTy = op.getDestVectorType();
3655 if (destTy.isScalable())
3658 for (Operation *user : op.getResult().getUsers())
3659 if (
auto insertOp = dyn_cast<InsertOp>(user))
3660 if (insertOp.getDest() == op.getResult())
3663 InsertOp currentOp = op;
3664 SmallVector<InsertOp> chainInsertOps;
3667 if (currentOp.hasDynamicPosition())
3670 chainInsertOps.push_back(currentOp);
3671 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3674 if (currentOp && !currentOp->hasOneUse())
3678 int64_t vectorSize = destTy.getNumElements();
3679 int64_t initializedCount = 0;
3680 SmallVector<bool> initializedDestIdxs(vectorSize,
false);
3681 SmallVector<int64_t> pendingInsertPos;
3682 SmallVector<int64_t> pendingInsertSize;
3683 SmallVector<Value> pendingInsertValues;
3685 for (
auto insertOp : chainInsertOps) {
3687 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3691 int64_t insertBeginPosition =
3696 int64_t insertSize = 1;
3697 if (
auto srcVectorType =
3698 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3699 insertSize = srcVectorType.getNumElements();
3701 assert(insertBeginPosition + insertSize <= vectorSize &&
3702 "insert would overflow the vector");
3704 for (
auto index : llvm::seq<int64_t>(insertBeginPosition,
3705 insertBeginPosition + insertSize)) {
3706 if (initializedDestIdxs[index])
3708 initializedDestIdxs[index] =
true;
3714 pendingInsertPos.push_back(insertBeginPosition);
3715 pendingInsertSize.push_back(insertSize);
3716 pendingInsertValues.push_back(insertOp.getValueToStore());
3718 if (initializedCount == vectorSize)
3723 if (initializedCount != vectorSize)
3726 SmallVector<Value> elements(vectorSize);
3727 for (
auto [insertBeginPosition, insertSize, valueToStore] :
3728 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3729 pendingInsertValues))) {
3730 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3732 if (!srcVectorType) {
3733 elements[insertBeginPosition] = valueToStore;
3737 SmallVector<Type> elementToInsertTypes(insertSize,
3738 srcVectorType.getElementType());
3740 auto elementsToInsert = vector::ToElementsOp::create(
3741 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3742 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3743 elements[insertBeginPosition + linearIdx] =
3744 elementsToInsert.getResult(linearIdx);
3758 int64_t maxVectorSizeFoldThreshold) {
3759 if (insertOp.hasDynamicPosition())
3762 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3770 VectorType destTy = insertOp.getDestVectorType();
3771 if (destTy.isScalable())
3775 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3776 !insertOp->hasOneUse())
3783 Type destEltType = destTy.getElementType();
3787 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3788 for (
auto value : denseSource.getValues<
Attribute>())
3794 auto allValues = llvm::to_vector(denseDst.getValues<
Attribute>());
3795 copy(insertedValues, allValues.begin() + insertBeginPosition);
3804 auto destInsert = insertOp.getDest().
getDefiningOp<InsertOp>();
3808 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3811 insertOp.
setOperand(1, destInsert.getDest());
3812 return insertOp.getResult();
// Registers the canonicalization patterns for vector.insert:
// broadcast-collapsing, splat folding and insert-chain completion.
3815void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3816 MLIRContext *context) {
3817 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3818 InsertChainFullyInitialized>(context);
// Folder for vector.insert.
//  * An insert with no indices whose stored value already has the result
//    type is the identity: fold to the stored value.
//  * Constant-destination folding is bounded by vectorSizeFoldThreshold
//    (256 elements) to avoid materializing huge dense constants.
// NOTE(review): extraction dropped interior lines (poison-index handling
// around `kPoisonIndex` and the call producing `inplaceFolded`); doc only.
3821OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3824 constexpr int64_t vectorSizeFoldThreshold = 256;
// Identity insert: whole-vector overwrite with matching type.
3828 if (getNumIndices() == 0 && getValueToStoreType() ==
getType())
3829 return getValueToStore();
3833 SmallVector<Value> operands = {getValueToStore(), getDest()};
// Poison position => fold handled via kPoisonIndex check (head dropped).
3839 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3842 *
this, adaptor.getValueToStore(), adaptor.getDest(),
3843 vectorSizeFoldThreshold)) {
3847 return inplaceFolded;
// Builder for vector.insert_strided_slice from integer offsets/strides:
// adds the two operands and attaches the `offsets` and `strides` array
// attributes (the attribute values themselves are built on dropped lines).
3854void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
3855 Value source, Value dest,
3856 ArrayRef<int64_t> offsets,
3857 ArrayRef<int64_t> strides) {
3858 result.addOperands({source, dest});
3862 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(
result.name),
3864 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(
result.name),
3869template <
typename OpType>
3873 StringRef attrName) {
3874 if (arrayAttr.size() >
shape.size())
3875 return op.emitOpError(
"expected ")
3876 << attrName <<
" attribute of rank no greater than vector rank";
3883template <
typename OpType>
3887 bool halfOpen =
true) {
3888 for (
auto attr : arrayAttr) {
3889 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3893 if (val < min || val >= upper)
3894 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
3895 <<
min <<
", " << upper <<
")";
3903template <
typename OpType>
3908 for (
auto [
index, attrDimPair] :
3909 llvm::enumerate(llvm::zip_first(arrayAttr,
shape))) {
3910 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3914 if (val < min || val >=
max)
3915 return op.emitOpError(
"expected ")
3916 << attrName <<
" dimension " <<
index <<
" to be confined to ["
3917 <<
min <<
", " <<
max <<
")";
3927template <
typename OpType>
3932 assert(arrayAttr1.size() <=
shape.size());
3933 assert(arrayAttr2.size() <=
shape.size());
3934 for (
auto [
index, it] :
3935 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2,
shape))) {
3936 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3937 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3941 if (val1 + val2 < 0 || val1 + val2 >=
max)
3942 return op.emitOpError(
"expected sum(")
3943 << attrName1 <<
", " << attrName2 <<
") dimension " <<
index
3944 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3952 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3954 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
// Verifier for vector.insert_strided_slice:
//  * offsets rank must equal the destination rank;
//  * strides rank must equal the source rank;
//  * source rank must not exceed destination rank;
//  * scalable-dim flags of source and destination must match position-wise
//    (source dims align with the trailing `rankDiff` dest dims), and a
//    scalable source dim must have the same size as its dest dim, since a
//    scalable dimension cannot be partially inserted into.
// NOTE(review): attribute-range checks between lines 3977 and 3990 were
// dropped by the extraction; documentation only, code untouched.
3957LogicalResult InsertStridedSliceOp::verify() {
3958 auto sourceVectorType = getSourceVectorType();
3959 auto destVectorType = getDestVectorType();
3960 auto offsets = getOffsetsAttr();
3961 auto strides = getStridesAttr();
3962 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3964 "expected offsets of same size as destination vector rank");
3965 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3966 return emitOpError(
"expected strides of same size as source vector rank");
3967 if (sourceVectorType.getRank() > destVectorType.getRank())
3969 "expected source rank to be no greater than destination rank");
// Left-pad the source shape with zeros so it lines up with dest dims.
3971 auto sourceShape = sourceVectorType.getShape();
3972 auto destShape = destVectorType.getShape();
3973 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3974 destShape.size() - sourceShape.size(), 0);
3975 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3976 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3977 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3986 offName,
"source vector shape",
// Check scalability agreement dim-by-dim (source dim idx maps to dest
// dim idx + rankDiff).
3990 unsigned rankDiff = destShape.size() - sourceShape.size();
3991 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3992 if (sourceVectorType.getScalableDims()[idx] !=
3993 destVectorType.getScalableDims()[idx + rankDiff]) {
3994 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3997 if (sourceVectorType.getScalableDims()[idx]) {
3998 auto sourceSize = sourceShape[idx];
3999 auto destSize = destShape[idx + rankDiff];
4000 if (sourceSize != destSize) {
4003 << (
" to match the corresponding base size from the input "
4005 << sourceSize << (
" vs ") << destSize << (
")");
// Rewrite pattern on vector.insert_strided_slice: if the inserted slice
// and the destination are splats of the same scalar, the insert changes
// nothing — replace the op with its destination.
4015class FoldInsertStridedSliceSplat final
4016 :
public OpRewritePattern<InsertStridedSliceOp> {
4020 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4021 PatternRewriter &rewriter)
const override {
// Both sides must broadcast the very same scalar Value.
4023 auto dst = insertStridedSliceOp.getDest();
4024 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4025 if (!splat || getScalarSplatSource(dst) != splat)
4028 rewriter.
replaceOp(insertStridedSliceOp, dst);
// Rewrite pattern: inserting a slice that was extracted from the same
// destination at the same offsets/strides puts the data back where it came
// from — the whole insert/extract pair is a no-op, so the op is replaced
// with its destination.
4035class FoldInsertStridedSliceOfExtract final
4036 :
public OpRewritePattern<InsertStridedSliceOp> {
4040 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4041 PatternRewriter &rewriter)
const override {
// The stored value must come from an extract_strided_slice...
4042 auto extractStridedSliceOp =
4043 insertStridedSliceOp.getValueToStore()
4044 .getDefiningOp<vector::ExtractStridedSliceOp>();
4046 if (!extractStridedSliceOp)
// ...of the same vector we are inserting into...
4049 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
// ...at identical strides and offsets.
4053 if (extractStridedSliceOp.getStrides() !=
4054 insertStridedSliceOp.getStrides() ||
4055 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4058 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4065class InsertStridedSliceConstantFolder final
4066 :
public OpRewritePattern<InsertStridedSliceOp> {
4072 static constexpr int64_t vectorSizeFoldThreshold = 256;
4074 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4075 PatternRewriter &rewriter)
const override {
4079 Attribute vectorDestCst;
4083 VectorType destTy = destVector.getType();
4084 if (destTy.isScalable())
4088 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4089 !destVector.hasOneUse())
4093 Attribute sourceCst;
4103 if (op.hasNonUnitStrides())
4106 VectorType sliceVecTy = sourceValue.getType();
4107 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4108 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4109 SmallVector<int64_t, 4> offsets =
getI64SubArray(op.getOffsets());
4110 SmallVector<int64_t, 4> destStrides =
computeStrides(destTy.getShape());
4118 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4119 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4120 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4121 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4122 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4123 MutableArrayRef<int64_t> currSlicePosition(
4124 currDestPosition.begin() + rankDifference, currDestPosition.end());
4125 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4128 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
4129 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
4130 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4131 "Invalid slice element");
4132 newValues[linearizedPosition] = *sliceValuesIt;
// Registers the canonicalization patterns for vector.insert_strided_slice:
// splat folding, insert-of-matching-extract elimination, and constant
// folding of dense destinations.
4145void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4146 RewritePatternSet &results, MLIRContext *context) {
4147 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4148 InsertStridedSliceConstantFolder>(context);
// Folder: when source and destination types are identical, the strided
// insert overwrites the whole destination — fold to the stored value.
4151OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4152 if (getSourceVectorType() == getDestVectorType())
4153 return getValueToStore();
4162void OuterProductOp::build(OpBuilder &builder, OperationState &
result,
4163 Value
lhs, Value
rhs, Value acc) {
4168void OuterProductOp::print(OpAsmPrinter &p) {
4169 p <<
" " << getLhs() <<
", " << getRhs();
4171 p <<
", " << getAcc();
4174 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
4177ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &
result) {
4178 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4185 if (operandsInfo.size() < 2)
4187 "expected at least 2 operands");
4188 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4189 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4192 "expected vector type for operand #1");
4196 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4197 vRHS.getScalableDims()[0]};
4198 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4199 vLHS.getElementType(), scalableDimsRes);
4202 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4203 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4207 if (!
result.attributes.get(OuterProductOp::getKindAttrName(
result.name))) {
4208 result.attributes.append(
4209 OuterProductOp::getKindAttrName(
result.name),
4210 CombiningKindAttr::get(
result.getContext(),
4211 OuterProductOp::getDefaultKind()));
4217 (operandsInfo.size() > 2 &&
// Verifier for vector.outerproduct. Two shapes are legal:
//  * vector RHS: 1-d LHS x 1-d RHS -> 2-d result whose dims match the
//    operand dims; LHS-only scalability is rejected;
//  * non-vector (scalar) RHS: acts as an AXPY, 1-d result matching LHS.
// Also requires the optional accumulator to have the result type and a
// `kind` CombiningKind attribute to be present.
// NOTE(review): the branch structure separating the vector-RHS and
// scalar-RHS paths sits partly on dropped extraction lines; doc only.
4222LogicalResult OuterProductOp::verify() {
4223 Type tRHS = getOperandTypeRHS();
4224 VectorType vLHS = getOperandVectorTypeLHS(),
4225 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4226 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4228 if (vLHS.getRank() != 1)
4229 return emitOpError(
"expected 1-d vector for operand #1");
// Vector RHS: classic outer product producing a 2-d result.
4233 if (vRHS.getRank() != 1)
4234 return emitOpError(
"expected 1-d vector for operand #2");
4235 if (vRES.getRank() != 2)
4237 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4238 return emitOpError(
"expected #1 operand dim to match result dim #1");
4239 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4240 return emitOpError(
"expected #2 operand dim to match result dim #2");
// Scalable LHS with fixed RHS is not representable.
4241 if (vLHS.isScalable() && !vRHS.isScalable()) {
4245 "expected either both or only #2 operand dim to be scalable");
// Scalar RHS: 1-d result matching the LHS dimension.
4249 if (vRES.getRank() != 1)
4251 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4252 return emitOpError(
"expected #1 operand dim to match result dim #1");
4255 if (vACC && vACC != vRES)
4256 return emitOpError(
"expected operand #3 of same type as result type");
// The combining kind attribute is mandatory.
4258 if (!getKindAttr()) {
4259 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
4260 "'vector.kind<add>')");
4265 return emitOpError(
"unsupported outerproduct type");
// The mask for an outerproduct has the result vector's shape and
// scalability, with an i1 element type.
4274Type OuterProductOp::getExpectedMaskType() {
4275 auto vecType = this->getResultVectorType();
4276 return VectorType::get(vecType.getShape(),
4277 IntegerType::get(vecType.getContext(), 1),
4278 vecType.getScalableDims());
4292 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4294 shape.reserve(vectorType.getRank());
4296 for (
unsigned e = offsets.size(); idx < e; ++idx)
4297 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4298 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4299 shape.push_back(vectorType.getShape()[idx]);
4301 return VectorType::get(
shape, vectorType.getElementType(),
4302 vectorType.getScalableDims());
4305void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4306 Value source, ArrayRef<int64_t> offsets,
4307 ArrayRef<int64_t> sizes,
4308 ArrayRef<int64_t> strides) {
4309 result.addOperands(source);
4315 offsetsAttr, sizesAttr, stridesAttr));
4316 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(
result.name),
4318 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(
result.name),
4320 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(
result.name),
4324LogicalResult ExtractStridedSliceOp::verify() {
4325 auto type = getSourceVectorType();
4326 auto offsets = getOffsetsAttr();
4327 auto sizes = getSizesAttr();
4328 auto strides = getStridesAttr();
4329 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4331 "expected offsets, sizes and strides attributes of same size");
4333 auto shape = type.getShape();
4334 auto offName = getOffsetsAttrName();
4335 auto sizesName = getSizesAttrName();
4336 auto stridesName = getStridesAttrName();
4352 shape, offName, sizesName,
4357 offsets, sizes, strides);
4358 if (getResult().
getType() != resultType)
4359 return emitOpError(
"expected result type to be ") << resultType;
4361 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
4362 if (type.getScalableDims()[idx]) {
4363 auto inputDim = type.getShape()[idx];
4364 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4365 if (inputDim != inputSize)
4368 << (
" to match the corresponding base size from the input "
4370 << inputSize << (
" vs ") << inputDim << (
")");
4383 auto getElement = [](
ArrayAttr array,
int idx) {
4384 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4386 ArrayAttr extractOffsets = op.getOffsets();
4389 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4391 if (op.getSourceVectorType().getRank() !=
4392 insertOp.getSourceVectorType().getRank())
4394 ArrayAttr insertOffsets = insertOp.getOffsets();
4395 ArrayAttr insertStrides = insertOp.getStrides();
4398 if (extractOffsets.size() > insertOffsets.size())
4400 bool patialoverlap =
false;
4401 bool disjoint =
false;
4403 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4404 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
4406 int64_t start = getElement(insertOffsets, dim);
4407 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4408 int64_t offset = getElement(extractOffsets, dim);
4409 int64_t size = getElement(extractSizes, dim);
4411 if (start <= offset && offset < end) {
4414 if (offset + size > end)
4415 patialoverlap =
true;
4416 offsetDiffs.push_back(offset - start);
4423 if (!disjoint && !patialoverlap) {
4424 op.setOperand(insertOp.getValueToStore());
4427 op.setOffsetsAttr(
b.getI64ArrayAttr(offsetDiffs));
4433 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4448 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4453 if (op.hasNonUnitStrides())
4456 VectorType sourceVecTy = op.getSourceVectorType();
4460 VectorType sliceVecTy = op.getType();
4462 int64_t rank = sliceVecTy.getRank();
4474 const auto denseValuesBegin = dense.value_begin<
Attribute>();
4476 sliceValues.reserve(sliceVecTy.getNumElements());
4480 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4482 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4483 }
while (succeeded(
incSlicePosition(currSlicePosition, sliceShape, offsets)));
4485 assert(
static_cast<int64_t>(sliceValues.size()) ==
4486 sliceVecTy.getNumElements() &&
4487 "Invalid number of slice elements");
4491OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4492 if (getSourceVectorType() == getResult().
getType())
4499 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4506void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4528class StridedSliceFolder final
4529 :
public OpRewritePattern<ExtractStridedSliceOp> {
4531 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4533 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4534 PatternRewriter &rewriter)
const override {
4535 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4539 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4542 SmallVector<int64_t> firstOffsets =
getI64SubArray(firstOp.getOffsets());
4543 SmallVector<int64_t> firstSizes =
getI64SubArray(firstOp.getSizes());
4544 SmallVector<int64_t> secondOffsets =
getI64SubArray(secondOp.getOffsets());
4545 SmallVector<int64_t> secondSizes =
getI64SubArray(secondOp.getSizes());
4547 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4548 SmallVector<int64_t> combinedOffsets(newRank, 0);
4549 SmallVector<int64_t> combinedSizes(newRank);
4550 ArrayRef<int64_t> firstSourceShape =
4551 firstOp.getSourceVectorType().getShape();
4552 for (
unsigned i = 0; i < newRank; ++i) {
4553 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4554 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4555 combinedOffsets[i] = off1 + off2;
4557 if (i < secondSizes.size()) {
4558 combinedSizes[i] = secondSizes[i];
4559 }
else if (i < firstSizes.size()) {
4560 combinedSizes[i] = firstSizes[i];
4562 combinedSizes[i] = firstSourceShape[i];
4566 SmallVector<int64_t> combinedStrides(newRank, 1);
4568 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4586class StridedSliceCreateMaskFolder final
4587 :
public OpRewritePattern<ExtractStridedSliceOp> {
4591 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4592 PatternRewriter &rewriter)
const override {
4593 Location loc = extractStridedSliceOp.getLoc();
4597 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4601 if (extractStridedSliceOp.hasNonUnitStrides())
4604 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4606 SmallVector<int64_t> sliceOffsets;
4609 SmallVector<int64_t> sliceSizes;
4613 SmallVector<Value> sliceMaskDimSizes;
4614 sliceMaskDimSizes.reserve(maskDimSizes.size());
4618 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4619 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4623 IntegerAttr offsetAttr =
4625 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4626 Value sliceMaskDimSize =
4627 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4628 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4633 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4637 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4645class StridedSliceConstantMaskFolder final
4646 :
public OpRewritePattern<ExtractStridedSliceOp> {
4650 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4651 PatternRewriter &rewriter)
const override {
4654 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4655 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4656 if (!constantMaskOp)
4659 if (extractStridedSliceOp.hasNonUnitStrides())
4662 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4664 SmallVector<int64_t> sliceOffsets;
4667 SmallVector<int64_t> sliceSizes;
4671 SmallVector<int64_t> sliceMaskDimSizes;
4672 sliceMaskDimSizes.reserve(maskDimSizes.size());
4673 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4674 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4675 int64_t sliceMaskDimSize = std::max(
4676 static_cast<int64_t
>(0),
4677 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4678 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4681 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4682 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4683 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4686 if (llvm::is_contained(sliceMaskDimSizes, 0))
4687 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4692 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4700class StridedSliceBroadcast final
4701 :
public OpRewritePattern<ExtractStridedSliceOp> {
4705 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4706 PatternRewriter &rewriter)
const override {
4712 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4713 auto dstVecType = llvm::cast<VectorType>(op.getType());
4714 unsigned dstRank = dstVecType.getRank();
4715 unsigned rankDiff = dstRank - srcRank;
4719 bool needsSlice =
false;
4720 for (
unsigned i = 0; i < srcRank; i++) {
4721 if (srcVecType.getDimSize(i) != 1 &&
4722 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4729 SmallVector<int64_t> offsets =
4731 SmallVector<int64_t> sizes =
4733 for (
unsigned i = 0; i < srcRank; i++) {
4734 if (srcVecType.getDimSize(i) == 1) {
4742 source = ExtractStridedSliceOp::create(
4743 rewriter, op->getLoc(), source, offsets, sizes,
4752class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
4756 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4757 PatternRewriter &rewriter)
const override {
4759 Value splat = getScalarSplatSource(op.getSource());
4783class ContiguousExtractStridedSliceToExtract final
4784 :
public OpRewritePattern<ExtractStridedSliceOp> {
4788 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4789 PatternRewriter &rewriter)
const override {
4790 if (op.hasNonUnitStrides())
4792 Value source = op.getOperand();
4793 auto sourceType = cast<VectorType>(source.
getType());
4794 if (sourceType.isScalable() || sourceType.getRank() == 0)
4803 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4804 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4811 if (numOffsets == 0)
4816 if (numOffsets == sourceType.getRank() &&
4817 static_cast<int>(sizes.size()) == sourceType.getRank())
4821 for (
int i = 0; i < numOffsets; ++i) {
4829 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
4830 sizes[numOffsets] == 1) {
4835 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4836 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4845void ExtractStridedSliceOp::getCanonicalizationPatterns(
4846 RewritePatternSet &results, MLIRContext *context) {
4849 results.
add<StridedSliceFolder, StridedSliceCreateMaskFolder,
4850 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4851 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4860void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4861 VectorType vectorType, Value source,
4863 AffineMapAttr permutationMapAttr,
4866 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4868 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4869 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
4870 *padding, Value(), inBoundsAttr);
4874void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4875 VectorType vectorType, Value source,
4877 AffineMap permutationMap,
4878 std::optional<ArrayRef<bool>> inBounds) {
4879 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4880 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4883 SmallVector<bool>(vectorType.getRank(),
false));
4884 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4886 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4887 build(builder,
result, vectorType, source,
indices, *padding,
4888 permutationMapAttr, inBoundsAttr);
4892void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4893 VectorType vectorType, Value source,
4895 std::optional<ArrayRef<bool>> inBounds) {
4897 llvm::cast<ShapedType>(source.
getType()), vectorType);
4898 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4899 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4902 SmallVector<bool>(vectorType.getRank(),
false));
4903 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4905 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4906 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
4908 Value(), inBoundsAttr);
4911template <
typename EmitFun>
4915 for (
auto expr : permutationMap.
getResults()) {
4916 auto dim = dyn_cast<AffineDimExpr>(expr);
4917 auto zero = dyn_cast<AffineConstantExpr>(expr);
4919 if (zero.getValue() != 0) {
4921 "requires a projected permutation_map (at most one dim or the zero "
4922 "constant can appear in each result)");
4927 return emitOpError(
"requires a projected permutation_map (at most one "
4928 "dim or the zero constant can appear in each result)");
4930 if (seen[dim.getPosition()]) {
4932 "requires a permutation_map that is a permutation (found one dim "
4933 "used more than once)");
4935 seen[dim.getPosition()] =
true;
4942 VectorType vectorType, VectorType maskType,
4943 VectorType inferredMaskType,
AffineMap permutationMap,
4945 if (op->hasAttr(
"masked")) {
4946 return op->emitOpError(
"masked attribute has been removed. "
4947 "Use in_bounds instead.");
4950 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4951 return op->emitOpError(
4952 "requires source to be a memref or ranked tensor type");
4954 auto elementType = shapedType.getElementType();
4956 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4958 unsigned sourceVecSize =
4960 vectorElementType.getShape().back();
4961 unsigned resultVecSize =
4963 vectorType.getShape().back();
4964 if (resultVecSize % sourceVecSize != 0)
4965 return op->emitOpError(
4966 "requires the bitwidth of the minor 1-D vector to be an integral "
4967 "multiple of the bitwidth of the minor 1-D vector of the source");
4969 unsigned sourceVecEltRank = vectorElementType.getRank();
4970 unsigned resultVecRank = vectorType.getRank();
4971 if (sourceVecEltRank > resultVecRank)
4972 return op->emitOpError(
4973 "requires source vector element and vector result ranks to match.");
4974 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4977 return op->emitOpError(
"requires a permutation_map with result dims of "
4978 "the same rank as the vector type");
4981 return op->emitOpError(
"does not support masks with vector element type");
4984 unsigned minorSize =
4985 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4986 unsigned resultVecSize =
4989 return op->emitOpError(
4990 "requires the bitwidth of the minor 1-D vector to be an integral "
4991 "multiple of the bitwidth of the source element type");
4995 return op->emitOpError(
"requires a permutation_map with result dims of "
4996 "the same rank as the vector type");
5000 return op->emitOpError(
"requires permutation_map without symbols");
5002 if (permutationMap.
getNumInputs() != shapedType.getRank())
5003 return op->emitOpError(
"requires a permutation_map with input dims of the "
5004 "same rank as the source type");
5006 if (maskType && maskType != inferredMaskType)
5007 return op->emitOpError(
"inferred mask type (")
5008 << inferredMaskType <<
") and mask operand type (" << maskType
5012 return op->emitOpError(
"expects the in_bounds attr of same rank "
5013 "as permutation_map results: ")
5014 << AffineMapAttr::get(permutationMap)
5015 <<
" vs inBounds of size: " << inBounds.size();
5022 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5023 if (op.getPermutationMap().isMinorIdentity())
5024 elidedAttrs.push_back(op.getPermutationMapAttrName());
5026 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
5027 elidedAttrs.push_back(op.getInBoundsAttrName());
5031void TransferReadOp::print(OpAsmPrinter &p) {
5034 p <<
", " << getMask();
5041 auto i1Type = IntegerType::get(permMap.
getContext(), 1);
5043 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
5048 if (maskShape.empty())
5049 maskShape.push_back(1);
5054 return VectorType::get(maskShape, i1Type, scalableDims);
5071 if (hasMask.succeeded()) {
5078 if (types.size() != 2)
5079 return parser.
emitError(typesLoc,
"requires two types");
5081 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5082 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5083 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5084 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5086 return parser.
emitError(typesLoc,
"requires vector type");
5087 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(
result.name);
5091 if (shapedType.getRank() <
5094 "expected a custom permutation_map when "
5095 "rank(source) != rank(destination)");
5097 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5099 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5101 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(
result.name);
5102 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5103 if (!inBoundsAttr) {
5104 result.addAttribute(inBoundsAttrName,
5113 if (hasMask.succeeded()) {
5114 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5116 maskInfo.
location,
"does not support masks with vector element type");
5119 "expected the same rank for the vector and the "
5120 "results of the permutation map");
5128 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5130 {1, static_cast<int32_t>(indexInfo.size()), 1,
5131 static_cast<int32_t>(hasMask.succeeded())}));
// Verifier for vector.transfer_read:
//  * the number of indices must equal the source ShapedType's rank;
//  * common transfer invariants (mask/permutation-map consistency) are
//    delegated to the shared transfer-op checker;
//  * padding rules: for a vector-element source the padding type must
//    equal that element vector type; otherwise the padding must be a valid
//    vector element type identical to the source element type.
// NOTE(review): the inferred-mask-type computation and the call head of
// the shared checker sit on dropped extraction lines; doc only.
5135LogicalResult TransferReadOp::verify() {
5137 ShapedType shapedType = getShapedType();
5139 VectorType maskType = getMaskType();
5140 auto paddingType = getPadding().getType();
5141 auto permutationMap = getPermutationMap();
5142 VectorType inferredMaskType =
5145 auto sourceElementType = shapedType.getElementType();
// One index per source dimension.
5147 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
5148 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5151 shapedType, vectorType, maskType,
5152 inferredMaskType, permutationMap, getInBounds())))
// Vector element type: padding must be exactly that vector type.
5155 if (
auto sourceVectorElementType =
5156 llvm::dyn_cast<VectorType>(sourceElementType)) {
5159 if (sourceVectorElementType != paddingType)
5161 "requires source element type and padding type to match.");
// Scalar element type: padding must be a valid, matching element type.
5165 if (!VectorType::isValidElementType(paddingType))
5166 return emitOpError(
"requires valid padding vector elemental type");
5169 if (paddingType != sourceElementType)
5171 "requires formal padding and source of the same elemental type");
5182Type TransferReadOp::getExpectedMaskType() {
5189VectorType TransferReadOp::getVectorType() {
5190 return cast<VectorType>(getVector().
getType());
5193template <
typename TransferOp>
5197 if (op.getShapedType().isDynamicDim(indicesIdx))
5201 if (!cstOp.has_value())
5204 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5205 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5207 return cstOp.value() + vectorSize <= sourceSize;
5210template <
typename TransferOp>
5214 if (op.getTransferRank() == 0)
5217 bool changed =
false;
5219 newInBounds.reserve(op.getTransferRank());
5224 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
5226 if (op.isDimInBounds(i)) {
5227 newInBounds.push_back(
true);
5232 bool inBounds =
false;
5233 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
5236 dimExpr.getPosition());
5237 nonBcastDims.push_back(i);
5240 newInBounds.push_back(inBounds);
5242 changed |= inBounds;
5248 bool allNonBcastDimsInBounds = llvm::all_of(
5249 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
5250 if (allNonBcastDimsInBounds) {
5252 changed |= !newInBounds[idx];
5253 newInBounds[idx] =
true;
5261 op.setInBoundsAttr(
b.getBoolArrayAttr(newInBounds));
5265template <
typename TransferOp>
5267 auto mask = op.getMask();
5274 op.getMaskMutable().clear();
5288static Value foldRAW(TransferReadOp readOp) {
5289 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5291 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5294 return defWrite.getVector();
5296 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5297 cast<VectorTransferOpInterface>(readOp.getOperation())))
5299 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5304OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5305 if (Value vec = foldRAW(*
this))
5316 return OpFoldResult();
5319std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5323void TransferReadOp::getEffects(
5324 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5326 if (llvm::isa<MemRefType>(getShapedType()))
5327 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5328 SideEffects::DefaultResource::get());
5332 if (hasPureTensorSemantics())
5339static AffineMap inverseWithUnusedDims(AffineMap map) {
5341 "expected a projected permutation map");
5346 int64_t pos = cast<AffineDimExpr>(
result).getPosition();
5376struct TransferReadAfterWriteToBroadcast
5377 :
public OpRewritePattern<TransferReadOp> {
5380 LogicalResult matchAndRewrite(TransferReadOp readOp,
5381 PatternRewriter &rewriter)
const override {
5382 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5386 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5390 if (readOp.getMask() || defWrite.getMask())
5393 if (readOp.getIndices() != defWrite.getIndices())
5396 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5400 if (readOp.getTransferChunkAccessed() !=
5401 defWrite.getTransferChunkAccessed())
5408 AffineMap readMap = readOp.getPermutationMap();
5409 AffineMap writeMap = defWrite.getPermutationMap();
5410 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5411 AffineMap composedMap = readMap.
compose(invWriteMap);
5425 int64_t numBroadcastedDims = broadcastedDims.size();
5426 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5428 for (
auto [idx, expr] : llvm::enumerate(composedMap.
getResults())) {
5429 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
5430 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5431 invPerm[effectiveDim] = idx;
5436 VectorType readVecTy = readOp.getVectorType();
5438 auto broadcastedVecTy =
5440 readVecTy.getElementType(),
5443 Value vec = defWrite.getVector();
5444 Location loc = readOp.getLoc();
5445 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5452void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5453 MLIRContext *context) {
5454 results.
add<TransferReadAfterWriteToBroadcast>(context);
5457FailureOr<std::optional<SmallVector<Value>>>
5458TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5459 if (!hasPureBufferSemantics())
5470void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5472 AffineMapAttr permutationMapAttr,
5475 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
5476 build(builder,
result, resultType, vector, dest,
indices, permutationMapAttr,
5477 mask, inBoundsAttr);
5481void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5483 AffineMapAttr permutationMapAttr,
5485 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5486 Value(), inBoundsAttr);
5491void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5493 AffineMap permutationMap,
5494 std::optional<ArrayRef<bool>> inBounds) {
5495 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5497 (inBounds && !inBounds.value().empty())
5500 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
5501 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5502 Value(), inBoundsAttr);
5507void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5509 std::optional<ArrayRef<bool>> inBounds) {
5510 auto vectorType = llvm::cast<VectorType>(vector.
getType());
5512 llvm::cast<ShapedType>(dest.
getType()), vectorType);
5513 build(builder,
result, vector, dest,
indices, permutationMap, inBounds);
5516ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5517 OperationState &
result) {
5520 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5521 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5522 SmallVector<Type, 2> types;
5523 OpAsmParser::UnresolvedOperand maskInfo;
5529 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
5534 if (types.size() != 2)
5535 return parser.
emitError(typesLoc,
"requires two types");
5537 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5539 return parser.
emitError(typesLoc,
"requires vector type");
5540 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5541 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5542 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5543 auto permMapAttrName =
5544 TransferWriteOp::getPermutationMapAttrName(
result.name);
5545 auto permMapAttr =
result.attributes.get(permMapAttrName);
5548 if (shapedType.getRank() <
5551 "expected a custom permutation_map when "
5552 "rank(source) != rank(destination)");
5554 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5556 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5558 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(
result.name);
5559 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5560 if (!inBoundsAttr) {
5561 result.addAttribute(inBoundsAttrName,
5569 if (hasMask.succeeded()) {
5570 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5572 maskInfo.
location,
"does not support masks with vector element type");
5575 "expected the same rank for the vector and the "
5576 "results of the permutation map");
5582 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5584 {1, 1, static_cast<int32_t>(indexInfo.size()),
5585 static_cast<int32_t>(hasMask.succeeded())}));
5586 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5590void TransferWriteOp::print(OpAsmPrinter &p) {
5593 p <<
", " << getMask();
5598LogicalResult TransferWriteOp::verify() {
5600 ShapedType shapedType = getShapedType();
5602 VectorType maskType = getMaskType();
5603 auto permutationMap = getPermutationMap();
5604 VectorType inferredMaskType =
5608 if (llvm::size(
getIndices()) != shapedType.getRank())
5609 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5613 if (hasBroadcastDim())
5614 return emitOpError(
"should not have broadcast dimensions");
5617 shapedType, vectorType, maskType,
5618 inferredMaskType, permutationMap, getInBounds())))
5631Type TransferWriteOp::getExpectedMaskType() {
5638Value TransferWriteOp::getVector() {
return getOperand(0); }
5639VectorType TransferWriteOp::getVectorType() {
5640 return cast<VectorType>(getValueToStore().
getType());
5663static LogicalResult foldReadInitWrite(TransferWriteOp write,
5664 ArrayRef<Attribute>,
5665 SmallVectorImpl<OpFoldResult> &results) {
5667 if (write.getTransferRank() == 0)
5669 auto rankedTensorType =
5670 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5672 if (!rankedTensorType)
5675 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5679 if (read.getTransferRank() == 0)
5682 if (!read.getPermutationMap().isMinorIdentity() ||
5683 !write.getPermutationMap().isMinorIdentity())
5686 if (read.getTransferRank() != write.getTransferRank())
5689 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5692 if (read.getBase().getType() != rankedTensorType)
5695 if (read.getVectorType() != write.getVectorType())
5698 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5701 auto isNotConstantZero = [](Value v) {
5703 return !cstOp.has_value() || cstOp.value() != 0;
5705 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5706 llvm::any_of(write.getIndices(), isNotConstantZero))
5709 results.push_back(read.getBase());
5713static bool checkSameValueWAR(vector::TransferReadOp read,
5714 vector::TransferWriteOp write) {
5715 return read.getBase() == write.getBase() &&
5716 read.getIndices() == write.getIndices() &&
5717 read.getPermutationMap() == write.getPermutationMap() &&
5718 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5735static LogicalResult foldWAR(TransferWriteOp write,
5736 SmallVectorImpl<OpFoldResult> &results) {
5737 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5739 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5743 if (!checkSameValueWAR(read, write))
5745 results.push_back(read.getBase());
5749LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5750 SmallVectorImpl<OpFoldResult> &results) {
5751 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
5753 if (succeeded(foldWAR(*
this, results)))
5765std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5769void TransferWriteOp::getEffects(
5770 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5772 if (llvm::isa<MemRefType>(getShapedType()))
5773 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5774 SideEffects::DefaultResource::get());
5778 if (hasPureTensorSemantics())
5808class FoldWaw final :
public OpRewritePattern<TransferWriteOp> {
5811 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5812 PatternRewriter &rewriter)
const override {
5813 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5815 vector::TransferWriteOp writeToModify = writeOp;
5817 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5821 writeToModify.getBaseMutable().assign(defWrite.getBase());
5826 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5827 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5831 if (!defWrite->hasOneUse())
5833 writeToModify = defWrite;
5834 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5863struct SwapExtractSliceOfTransferWrite
5864 :
public OpRewritePattern<tensor::InsertSliceOp> {
5868 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5869 PatternRewriter &rewriter)
const override {
5870 if (!insertOp.hasUnitStride())
5873 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5874 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5876 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5877 if (!transferOp || !transferOp->hasOneUse())
5882 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5884 "use-def chain is rank-reducing");
5888 if (!extractOp.hasZeroOffset()) {
5890 "ExtractSliceOp has non-zero offset");
5894 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5895 return getConstantIntValue(value) == static_cast<int64_t>(0);
5898 "TranferWriteOp has non-zero offset");
5902 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5904 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
5907 for (
auto [insertSize, extractSize] :
5908 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5911 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
5916 assert(transferOp.getVectorType().hasStaticShape() &&
5917 "expected vector to have a static shape");
5918 ArrayRef<int64_t>
vectorShape = transferOp.getVectorType().getShape();
5920 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5921 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
5923 insertOp,
"TransferWriteOp may not write the full tensor.");
5928 SmallVector<bool> newInBounds(
vectorShape.size(),
false);
5929 auto newExtractOp = tensor::ExtractSliceOp::create(
5930 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5931 insertOp.getDest(), insertOp.getMixedOffsets(),
5932 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5933 auto newTransferWriteOp = TransferWriteOp::create(
5934 rewriter, transferOp.getLoc(), transferOp.getVector(),
5935 newExtractOp.getResult(), transferOp.getIndices(),
5936 transferOp.getPermutationMapAttr(),
5939 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5947void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5948 MLIRContext *context) {
5949 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5952FailureOr<std::optional<SmallVector<Value>>>
5953TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5954 if (!hasPureBufferSemantics())
5964static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5966 MemRefType memRefTy) {
5969 if (!vecTy.isScalable() &&
5970 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5973 if (!memRefTy.isLastDimUnitStride())
5974 return op->
emitOpError(
"most minor memref dim must have unit stride");
5978LogicalResult vector::LoadOp::verify() {
5982 if (
failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
5985 if (memRefTy.getRank() < resVecTy.getRank())
5987 "destination memref has lower rank than the result vector");
5990 Type memElemTy = memRefTy.getElementType();
5991 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5992 if (memVecTy != resVecTy)
5993 return emitOpError(
"base memref and result vector types should match");
5994 memElemTy = memVecTy.getElementType();
5997 if (resVecTy.getElementType() != memElemTy)
5998 return emitOpError(
"base and result element types should match");
5999 if (llvm::size(
getIndices()) != memRefTy.getRank())
6000 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6004OpFoldResult LoadOp::fold(FoldAdaptor) {
6007 return OpFoldResult();
6010std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6014FailureOr<std::optional<SmallVector<Value>>>
6015LoadOp::bubbleDownCasts(OpBuilder &builder) {
6024LogicalResult vector::StoreOp::verify() {
6028 if (
failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
6031 if (memRefTy.getRank() < valueVecTy.getRank())
6032 return emitOpError(
"source memref has lower rank than the vector to store");
6035 Type memElemTy = memRefTy.getElementType();
6036 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6037 if (memVecTy != valueVecTy)
6039 "base memref and valueToStore vector types should match");
6040 memElemTy = memVecTy.getElementType();
6043 if (valueVecTy.getElementType() != memElemTy)
6044 return emitOpError(
"base and valueToStore element type should match");
6045 if (llvm::size(
getIndices()) != memRefTy.getRank())
6046 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
6050LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6051 SmallVectorImpl<OpFoldResult> &results) {
6055std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6059FailureOr<std::optional<SmallVector<Value>>>
6060StoreOp::bubbleDownCasts(OpBuilder &builder) {
6069LogicalResult MaskedLoadOp::verify() {
6070 VectorType maskVType = getMaskVectorType();
6071 VectorType passVType = getPassThruVectorType();
6078 if (llvm::size(
getIndices()) != memType.getRank())
6079 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6080 if (resVType.getShape() != maskVType.getShape())
6081 return emitOpError(
"expected result shape to match mask shape");
6082 if (resVType != passVType)
6083 return emitOpError(
"expected pass_thru of same type as result type");
6088class MaskedLoadFolder final :
public OpRewritePattern<MaskedLoadOp> {
6091 LogicalResult matchAndRewrite(MaskedLoadOp
load,
6092 PatternRewriter &rewriter)
const override {
6104 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
6109void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6110 MLIRContext *context) {
6111 results.
add<MaskedLoadFolder>(context);
6114OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6117 return OpFoldResult();
6120FailureOr<std::optional<SmallVector<Value>>>
6121MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6130LogicalResult MaskedStoreOp::verify() {
6131 VectorType maskVType = getMaskVectorType();
6138 if (llvm::size(
getIndices()) != memType.getRank())
6139 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6140 if (valueVType.getShape() != maskVType.getShape())
6141 return emitOpError(
"expected valueToStore shape to match mask shape");
6146class MaskedStoreFolder final :
public OpRewritePattern<MaskedStoreOp> {
6149 LogicalResult matchAndRewrite(MaskedStoreOp store,
6150 PatternRewriter &rewriter)
const override {
6154 store, store.getValueToStore(), store.getBase(), store.getIndices());
6162 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
6167void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6168 MLIRContext *context) {
6169 results.
add<MaskedStoreFolder>(context);
6172LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6173 SmallVectorImpl<OpFoldResult> &results) {
6177FailureOr<std::optional<SmallVector<Value>>>
6178MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6187LogicalResult GatherOp::verify() {
6188 VectorType indVType = getIndexVectorType();
6189 VectorType maskVType = getMaskVectorType();
6191 ShapedType baseType = getBaseType();
6193 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6194 return emitOpError(
"requires base to be a memref or ranked tensor type");
6199 if (llvm::size(getOffsets()) != baseType.getRank())
6200 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6201 if (resVType.getShape() != indVType.getShape())
6202 return emitOpError(
"expected result dim to match indices dim");
6203 if (resVType.getShape() != maskVType.getShape())
6204 return emitOpError(
"expected result dim to match mask dim");
6205 if (resVType != getPassThruVectorType())
6206 return emitOpError(
"expected pass_thru of same type as result type");
6214Type GatherOp::getExpectedMaskType() {
6215 auto vecType = this->getIndexVectorType();
6216 return VectorType::get(vecType.getShape(),
6217 IntegerType::get(vecType.getContext(), 1),
6218 vecType.getScalableDims());
6221std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6226static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6227 auto vecType = dyn_cast<VectorType>(indexVec.
getType());
6228 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6234 DenseIntElementsAttr elements;
6239 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6243class GatherFolder final :
public OpRewritePattern<GatherOp> {
6246 LogicalResult matchAndRewrite(GatherOp gather,
6247 PatternRewriter &rewriter)
const override {
6252 rewriter.
replaceOp(gather, gather.getPassThru());
6257 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
6263class FoldContiguousGather final :
public OpRewritePattern<GatherOp> {
6266 LogicalResult matchAndRewrite(GatherOp op,
6267 PatternRewriter &rewriter)
const override {
6268 if (!isa<MemRefType>(op.getBase().getType()))
6271 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6275 op.getOffsets(), op.getMask(),
6282void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6283 MLIRContext *context) {
6284 results.
add<GatherFolder, FoldContiguousGather>(context);
6287FailureOr<std::optional<SmallVector<Value>>>
6288GatherOp::bubbleDownCasts(OpBuilder &builder) {
6297LogicalResult ScatterOp::verify() {
6298 VectorType indVType = getIndexVectorType();
6299 VectorType maskVType = getMaskVectorType();
6301 ShapedType baseType = getBaseType();
6303 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6304 return emitOpError(
"requires base to be a memref or ranked tensor type");
6309 if (llvm::size(getOffsets()) != baseType.getRank())
6310 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
6311 if (valueVType.getShape() != indVType.getShape())
6312 return emitOpError(
"expected valueToStore dim to match indices dim");
6313 if (valueVType.getShape() != maskVType.getShape())
6314 return emitOpError(
"expected valueToStore dim to match mask dim");
6318class ScatterFolder final :
public OpRewritePattern<ScatterOp> {
6321 LogicalResult matchAndRewrite(ScatterOp scatter,
6322 PatternRewriter &rewriter)
const override {
6323 ShapedType baseType = scatter.getBaseType();
6324 bool isMemRef = isa<MemRefType>(baseType);
6325 if (!isMemRef && !isa<RankedTensorType>(baseType))
6338 rewriter.
replaceOp(scatter, scatter.getBase());
6343 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
6349class FoldContiguousScatter final :
public OpRewritePattern<ScatterOp> {
6352 LogicalResult matchAndRewrite(ScatterOp op,
6353 PatternRewriter &rewriter)
const override {
6356 if (!isa<MemRefType>(op.getBase().getType()))
6359 if (
failed(isZeroBasedContiguousSeq(op.getIndices())))
6363 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6369void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6370 MLIRContext *context) {
6371 results.
add<ScatterFolder, FoldContiguousScatter>(context);
6374FailureOr<std::optional<SmallVector<Value>>>
6375ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6384LogicalResult ExpandLoadOp::verify() {
6385 VectorType maskVType = getMaskVectorType();
6386 VectorType passVType = getPassThruVectorType();
6393 if (llvm::size(
getIndices()) != memType.getRank())
6394 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6395 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6396 return emitOpError(
"expected result dim to match mask dim");
6397 if (resVType != passVType)
6398 return emitOpError(
"expected pass_thru of same type as result type");
6403class ExpandLoadFolder final :
public OpRewritePattern<ExpandLoadOp> {
6406 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6407 PatternRewriter &rewriter)
const override {
6411 expand, expand.getType(), expand.getBase(), expand.getIndices());
6414 rewriter.
replaceOp(expand, expand.getPassThru());
6419 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
6424void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6425 MLIRContext *context) {
6426 results.
add<ExpandLoadFolder>(context);
6429FailureOr<std::optional<SmallVector<Value>>>
6430ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6439LogicalResult CompressStoreOp::verify() {
6440 VectorType maskVType = getMaskVectorType();
6447 if (llvm::size(
getIndices()) != memType.getRank())
6448 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6449 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6450 return emitOpError(
"expected valueToStore dim to match mask dim");
6455class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
6458 LogicalResult matchAndRewrite(CompressStoreOp compress,
6459 PatternRewriter &rewriter)
const override {
6463 compress, compress.getValueToStore(), compress.getBase(),
6464 compress.getIndices());
6472 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
6477void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6478 MLIRContext *context) {
6479 results.
add<CompressStoreFolder>(context);
6482FailureOr<std::optional<SmallVector<Value>>>
6483CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6492void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6494 setResultRanges(getResult(), argRanges.front());
6497std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6498 return llvm::to_vector<4>(getResultVectorType().
getShape());
6501LogicalResult ShapeCastOp::verify() {
6503 VectorType sourceType = getSourceVectorType();
6504 VectorType resultType = getResultVectorType();
6512 int64_t sourceNElms = sourceType.getNumElements();
6513 int64_t resultNElms = resultType.getNumElements();
6514 if (sourceNElms != resultNElms) {
6515 return emitOpError() <<
"has different number of elements at source ("
6516 << sourceNElms <<
") and result (" << resultNElms
6521 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6522 int64_t resultNScalableDims = resultType.getNumScalableDims();
6523 if (sourceNScalableDims != resultNScalableDims)
6524 return emitOpError() <<
"has different number of scalable dims at source ("
6525 << sourceNScalableDims <<
") and result ("
6526 << resultNScalableDims <<
")";
6535static bool isOrderPreserving(TransposeOp transpose) {
6536 ArrayRef<int64_t> permutation = transpose.getPermutation();
6537 VectorType sourceType = transpose.getSourceVectorType();
6538 ArrayRef<int64_t> inShape = sourceType.getShape();
6539 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6540 auto isNonScalableUnitDim = [&](int64_t dim) {
6541 return inShape[dim] == 1 && !inDimIsScalable[dim];
6543 int64_t current = 0;
6544 for (
auto p : permutation) {
6545 if (!isNonScalableUnitDim(p)) {
6555OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6557 VectorType resultType =
getType();
6560 if (getSource().
getType() == resultType)
6564 if (
auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6565 setOperand(precedingShapeCast.getSource());
6570 if (
auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6571 if (isOrderPreserving(transpose)) {
6572 setOperand(transpose.getVector());
6580 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6581 if (bcastOp.getSourceType() == resultType)
6582 return bcastOp.getSource();
6586 if (
auto denseAttr =
6587 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6588 return denseAttr.reshape(
getType());
6604static VectorType trimTrailingOneDims(VectorType oldType) {
6605 ArrayRef<int64_t> oldShape = oldType.getShape();
6606 ArrayRef<int64_t> newShape = oldShape;
6608 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6609 ArrayRef<bool> newScalableDims = oldScalableDims;
6611 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6612 newShape = newShape.drop_back(1);
6613 newScalableDims = newScalableDims.drop_back(1);
6618 if (newShape.empty()) {
6619 newShape = oldShape.take_back();
6620 newScalableDims = oldScalableDims.take_back();
6623 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6638class ShapeCastCreateMaskFolderTrailingOneDim final
6639 :
public OpRewritePattern<ShapeCastOp> {
6643 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6644 PatternRewriter &rewriter)
const override {
6645 Value shapeOpSrc = shapeOp->getOperand(0);
6646 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
6647 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
6648 if (!createMaskOp && !constantMaskOp)
6651 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6652 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6654 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6655 if (newVecType != shapeOpResTy)
6658 auto numDimsToDrop =
6659 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6666 auto maskOperands = createMaskOp.getOperands();
6667 auto numMaskOperands = maskOperands.size();
6670 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6672 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6673 if (!constant || (constant.value() != 1))
6676 SmallVector<Value> newMaskOperands =
6677 maskOperands.drop_back(numDimsToDrop);
6684 if (constantMaskOp) {
6685 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6686 auto numMaskOperands = maskDimSizes.size();
6689 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6691 if (maskDimSizes[i] != 1)
6695 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6706class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
6710 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6711 PatternRewriter &rewriter)
const override {
6713 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6717 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6718 bool srcIsScalar = !srcVectorType;
6726 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6728 BroadcastableToResult::Success) {
6730 shapeCastOp, dstVectorType, broadcastOp.getSource());
6751class FoldShapeCastOfFromElements final :
public OpRewritePattern<ShapeCastOp> {
6755 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6756 PatternRewriter &rewriter)
const override {
6757 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
6762 shapeCastOp, shapeCastOp.getResultVectorType(),
6763 fromElements.getElements());
6770void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6771 MLIRContext *context) {
6772 results.
add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
6773 FoldShapeCastOfFromElements>(context);
6780LogicalResult BitCastOp::verify() {
6781 auto sourceVectorType = getSourceVectorType();
6782 auto resultVectorType = getResultVectorType();
6784 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6785 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6786 return emitOpError(
"dimension size mismatch at: ") << i;
6789 DataLayout dataLayout = DataLayout::closest(*
this);
6790 auto sourceElementBits =
6792 auto resultElementBits =
6795 if (sourceVectorType.getRank() == 0) {
6796 if (sourceElementBits != resultElementBits)
6797 return emitOpError(
"source/result bitwidth of the 0-D vector element "
6798 "types must be equal");
6799 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
6800 resultElementBits * resultVectorType.getShape().back()) {
6802 "source/result bitwidth of the minor 1-D vectors must be equal");
6808OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6814 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6815 if (getResult().
getType() == otherOp.getSource().getType())
6816 return otherOp.getSource();
6818 setOperand(otherOp.getSource());
6822 Attribute sourceConstant = adaptor.getSource();
6823 if (!sourceConstant)
6826 Type srcElemType = getSourceVectorType().getElementType();
6827 Type dstElemType = getResultVectorType().getElementType();
6829 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6830 if (floatPack.isSplat()) {
6831 auto splat = floatPack.getSplatValue<FloatAttr>();
6834 if (srcElemType.
isF16() && dstElemType.
isF32()) {
6835 uint32_t bits =
static_cast<uint32_t
>(
6836 splat.getValue().bitcastToAPInt().getZExtValue());
6838 bits = (bits << 16) | (bits & 0xffff);
6839 APInt intBits(32, bits);
6840 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6846 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6847 if (intPack.isSplat()) {
6848 auto splat = intPack.getSplatValue<IntegerAttr>();
6850 if (llvm::isa<IntegerType>(dstElemType) && srcElemType.
isIntOrFloat()) {
6855 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6856 APInt intBits = splat.getValue().zext(dstBitWidth);
6859 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6860 intBits = (intBits << srcBitWidth) | intBits;
6874static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6875 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6876 SmallVector<int64_t, 8> res(memRefType.getShape());
6878 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6884void TypeCastOp::build(OpBuilder &builder, OperationState &
result,
6886 result.addOperands(source);
6887 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
6888 VectorType vectorType =
6889 VectorType::get(extractShape(memRefType),
6891 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6892 memRefType.getMemorySpace()));
6895LogicalResult TypeCastOp::verify() {
6896 MemRefType canonicalType =
getMemRefType().canonicalizeStridedLayout();
6897 if (!canonicalType.getLayout().isIdentity())
6898 return emitOpError(
"expects operand to be a memref with identity layout");
6899 if (!getResultMemRefType().getLayout().isIdentity())
6900 return emitOpError(
"expects result to be a memref with identity layout");
6901 if (getResultMemRefType().getMemorySpace() !=
6903 return emitOpError(
"expects result in same memory space");
6906 auto resultType = getResultMemRefType();
6910 "expects result and operand with same underlying scalar type: ")
6912 if (extractShape(sourceType) != extractShape(resultType))
6914 "expects concatenated result and operand shapes to be equal: ")
6923void vector::TransposeOp::build(OpBuilder &builder, OperationState &
result,
6924 Value vector, ArrayRef<int64_t> permutation) {
6925 VectorType vt = llvm::cast<VectorType>(vector.
getType());
6926 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6927 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6928 for (
unsigned i = 0; i < permutation.size(); ++i) {
6929 transposedShape[i] = vt.getShape()[permutation[i]];
6930 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6933 result.addOperands(vector);
6934 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6935 transposedScalableDims));
6936 result.addAttribute(TransposeOp::getPermutationAttrName(
result.name),
6940OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6943 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6944 return splat.reshape(getResultVectorType());
6961 if (getSourceVectorType() == getResultVectorType() &&
6962 isOrderPreserving(*
this))
6968LogicalResult vector::TransposeOp::verify() {
6969 VectorType vectorType = getSourceVectorType();
6970 VectorType resultType = getResultVectorType();
6971 int64_t rank = resultType.getRank();
6972 if (vectorType.getRank() != rank)
6973 return emitOpError(
"vector result rank mismatch: ") << rank;
6975 ArrayRef<int64_t> perm = getPermutation();
6976 int64_t size = perm.size();
6978 return emitOpError(
"transposition length mismatch: ") << size;
6979 SmallVector<bool, 8> seen(rank,
false);
6980 for (
const auto &ta : llvm::enumerate(perm)) {
6981 if (ta.value() < 0 || ta.value() >= rank)
6982 return emitOpError(
"transposition index out of range: ") << ta.value();
6983 if (seen[ta.value()])
6984 return emitOpError(
"duplicate position index: ") << ta.value();
6985 seen[ta.value()] =
true;
6986 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6987 return emitOpError(
"dimension size mismatch at: ") << ta.value();
6992std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6993 return llvm::to_vector<4>(getResultVectorType().
getShape());
6996void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6998 setResultRanges(getResult(), argRanges.front());
7004class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
7008 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7009 PatternRewriter &rewriter)
const override {
7011 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7012 ArrayRef<int64_t> permutation2) {
7013 SmallVector<int64_t, 4>
result;
7014 for (
auto index : permutation2)
7015 result.push_back(permutation1[index]);
7020 vector::TransposeOp parentTransposeOp =
7021 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7022 if (!parentTransposeOp)
7025 SmallVector<int64_t, 4> permutation = composePermutations(
7026 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7029 transposeOp, transposeOp.getResult().
getType(),
7030 parentTransposeOp.getVector(), permutation);
7036class FoldTransposeSplat final :
public OpRewritePattern<TransposeOp> {
7040 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7041 PatternRewriter &rewriter)
const override {
7042 Value splat = getScalarSplatSource(transposeOp.getVector());
7047 transposeOp, transposeOp.getResultVectorType(), splat);
7053class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
7057 LogicalResult matchAndRewrite(TransposeOp transpOp,
7058 PatternRewriter &rewriter)
const override {
7059 Value transposeSrc = transpOp.getVector();
7060 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
7061 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
7062 if (!createMaskOp && !constantMaskOp)
7067 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7070 auto maskOperands = createMaskOp.getOperands();
7071 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7075 transpOp, transpOp.getResultVectorType(), newOperands);
7080 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7084 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7090class FoldTransposeShapeCast final :
public OpRewritePattern<TransposeOp> {
7094 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7095 PatternRewriter &rewriter)
const override {
7097 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7100 if (!isOrderPreserving(transposeOp))
7103 VectorType resultType = transposeOp.getType();
7110 shapeCastOp.getSource());
7129class FoldTransposeFromElements final :
public OpRewritePattern<TransposeOp> {
7132 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7133 PatternRewriter &rewriter)
const override {
7134 auto fromElementsOp =
7135 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7136 if (!fromElementsOp)
7139 VectorType srcTy = fromElementsOp.getDest().getType();
7140 VectorType dstTy = transposeOp.getType();
7142 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7143 int64_t rank = srcTy.getRank();
7146 SmallVector<int64_t> inversePerm(rank, 0);
7147 for (int64_t i = 0; i < rank; ++i)
7148 inversePerm[permutation[i]] = i;
7150 ArrayRef<int64_t> srcShape = srcTy.getShape();
7151 ArrayRef<int64_t> dstShape = dstTy.getShape();
7152 SmallVector<int64_t> srcIdx(rank, 0);
7153 SmallVector<int64_t> dstIdx(rank, 0);
7157 auto elementsOld = fromElementsOp.getElements();
7158 SmallVector<Value> elementsNew;
7159 int64_t dstNumElements = dstTy.getNumElements();
7160 elementsNew.reserve(dstNumElements);
7164 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7168 for (int64_t j = 0; j < rank; ++j)
7169 srcIdx[j] = dstIdx[inversePerm[j]];
7171 int64_t srcLin =
linearize(srcIdx, srcStrides);
7173 elementsNew.push_back(elementsOld[srcLin]);
7207class FoldTransposeBroadcast :
public OpRewritePattern<vector::TransposeOp> {
7210 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7211 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7213 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7214 PatternRewriter &rewriter)
const override {
7220 "not preceded by a broadcast");
7223 auto inputType = dyn_cast<VectorType>(
broadcast.getSourceType());
7224 VectorType outputType = transpose.getResultVectorType();
7227 bool inputIsScalar = !inputType;
7228 if (inputIsScalar) {
7234 ArrayRef<int64_t> permutation = transpose.getPermutation();
7235 ArrayRef<int64_t> inputShape = inputType.getShape();
7236 int64_t inputRank = inputType.getRank();
7237 int64_t outputRank = transpose.getType().getRank();
7238 int64_t deltaRank = outputRank - inputRank;
7241 for (
int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7242 bool notOne = inputShape[inputIndex] != 1;
7243 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7244 bool groupEndFound = notOne || prevNotOne;
7245 if (groupEndFound) {
7246 int high = inputIndex + deltaRank;
7250 for (
int i = low; i < high; ++i) {
7251 if (permutation[i] < low || permutation[i] >= high) {
7253 transpose,
"permutation not local to group");
7267 vector::BroadcastableToResult::Success &&
7268 "not broadcastable directly to transpose output");
7279void vector::TransposeOp::getCanonicalizationPatterns(
7280 RewritePatternSet &results, MLIRContext *context) {
7281 results.
add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7282 FoldTransposeSplat, FoldTransposeFromElements,
7283 FoldTransposeBroadcast>(context);
7290void ConstantMaskOp::build(OpBuilder &builder, OperationState &
result,
7292 assert(kind == ConstantMaskKind::AllTrue ||
7293 kind == ConstantMaskKind::AllFalse);
7294 build(builder,
result, type,
7295 kind == ConstantMaskKind::AllTrue
7297 : SmallVector<int64_t>(type.getRank(), 0));
7300LogicalResult ConstantMaskOp::verify() {
7301 auto resultType = llvm::cast<VectorType>(getResult().
getType());
7303 if (resultType.getRank() == 0) {
7304 if (getMaskDimSizes().size() != 1)
7305 return emitError(
"array attr must have length 1 for 0-D vectors");
7306 auto dim = getMaskDimSizes()[0];
7307 if (dim != 0 && dim != 1)
7308 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
7313 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
7315 "must specify array attr of size equal vector result rank");
7318 auto resultShape = resultType.getShape();
7319 auto resultScalableDims = resultType.getScalableDims();
7320 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7321 for (
const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7322 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7324 "array attr of size out of bounds of vector result dimension size");
7325 if (resultScalableDims[index] && maskDimSize != 0 &&
7326 maskDimSize != resultShape[index])
7328 "only supports 'none set' or 'all set' scalable dimensions");
7332 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7333 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
7334 if (anyZeros && !allZeros)
7335 return emitOpError(
"expected all mask dim sizes to be zeros, "
7336 "as a result of conjunction with zero mask dim");
7340bool ConstantMaskOp::isAllOnesMask() {
7343 if (resultType.getRank() == 0) {
7344 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
7345 return getMaskDimSizes()[0] == 1;
7347 for (
const auto [resultSize, maskDimSize] :
7348 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7349 if (maskDimSize < resultSize)
7355OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7356 ArrayRef<int64_t> bounds = getMaskDimSizes();
7359 auto createBoolSplat = [&](
bool x) {
7365 if (vectorSizes.empty()) {
7366 assert(bounds.size() == 1 &&
"invalid sizes for zero rank mask");
7367 return createBoolSplat(bounds[0] == 1);
7370 if (bounds == vectorSizes)
7371 return createBoolSplat(
true);
7372 if (llvm::all_of(bounds, [](int64_t x) {
return x == 0; }))
7373 return createBoolSplat(
false);
7374 return OpFoldResult();
7381void CreateMaskOp::build(OpBuilder &builder, OperationState &
result,
7383 ArrayRef<OpFoldResult> mixedOperands) {
7384 SmallVector<Value> operands =
7386 build(builder,
result, type, operands);
7389LogicalResult CreateMaskOp::verify() {
7390 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
7392 if (vectorType.getRank() == 0) {
7393 if (getNumOperands() != 1)
7395 "must specify exactly one operand for 0-D create_mask");
7396 }
else if (getNumOperands() !=
7397 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
7399 "must specify an operand for each result vector dimension");
7429class CreateMaskFolder final :
public OpRewritePattern<CreateMaskOp> {
7433 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7434 PatternRewriter &rewriter)
const override {
7435 VectorType maskType = createMaskOp.getVectorType();
7436 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7437 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7440 constexpr std::array<int64_t, 1> rankZeroShape{1};
7441 constexpr std::array<bool, 1> rankZeroScalableDims{
false};
7442 if (maskType.getRank() == 0) {
7443 maskTypeDimSizes = rankZeroShape;
7444 maskTypeDimScalableFlags = rankZeroScalableDims;
7449 SmallVector<int64_t, 4> constantDims;
7450 for (
auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7455 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7457 constantDims.push_back(*intSize);
7461 if (vscaleMultiplier < maskTypeDimSizes[i])
7463 constantDims.push_back(*vscaleMultiplier);
7470 for (
auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7471 value = std::clamp<int64_t>(value, 0, maskDimSize);
7474 if (llvm::is_contained(constantDims, 0))
7475 constantDims.assign(constantDims.size(), 0);
7486void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7487 MLIRContext *context) {
7488 results.
add<CreateMaskFolder>(context);
7496 OpBuilder &builder, OperationState &
result, Value mask,
7497 Operation *maskableOp,
7498 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7499 assert(maskRegionBuilder &&
7500 "builder callback for 'maskRegion' must be present");
7502 result.addOperands(mask);
7503 OpBuilder::InsertionGuard guard(builder);
7504 Region *maskRegion =
result.addRegion();
7506 maskRegionBuilder(builder, maskableOp);
7511 Value mask, Operation *maskableOp,
7512 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7513 build(builder,
result, resultTypes, mask, Value(), maskableOp,
7519 Value mask, Value passthru, Operation *maskableOp,
7520 function_ref<
void(OpBuilder &, Operation *)> maskRegionBuilder) {
7521 build(builder,
result, mask, maskableOp, maskRegionBuilder);
7523 result.addOperands(passthru);
7524 result.addTypes(resultTypes);
7527ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &
result) {
7529 result.regions.reserve(1);
7530 Region &maskRegion = *
result.addRegion();
7535 OpAsmParser::UnresolvedOperand mask;
7540 OpAsmParser::UnresolvedOperand passthru;
7542 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
7549 MaskOp::ensureTerminator(maskRegion, builder,
result.location);
7560 SmallVector<Type> resultTypes;
7563 result.types.append(resultTypes);
7569 if (parsePassthru.succeeded()) {
7570 if (resultTypes.empty())
7573 "expects a result if passthru operand is provided");
7582void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7583 p <<
" " << getMask();
7585 p <<
", " << getPassthru();
7589 Block *singleBlock = &getMaskRegion().getBlocks().front();
7596 p <<
" : " << getMask().getType();
7597 if (getNumResults() > 0)
7598 p <<
" -> " << getResultTypes();
7601void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) {
7604 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7605 MaskOp>::ensureTerminator(region, builder, loc);
7611 if (isa<vector::YieldOp>(block.
back()))
7619 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7620 MaskOp>::ensureTerminator(region, builder, loc);
7626 Operation *maskedOp = &block.
front();
7627 opBuilder.setInsertionPointToEnd(&block);
7628 vector::YieldOp::create(opBuilder, loc, maskedOp->
getResults());
7631LogicalResult MaskOp::verify() {
7633 Block &block = getMaskRegion().getBlocks().
front();
7635 return emitOpError(
"expects a terminator within the mask region");
7638 if (numMaskRegionOps > 2)
7639 return emitOpError(
"expects only one operation to mask");
7642 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
7644 return emitOpError(
"expects a terminator within the mask region");
7646 if (terminator->getNumOperands() != getNumResults())
7648 "expects number of results to match mask region yielded values");
7651 if (numMaskRegionOps == 1)
7654 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
7656 return emitOpError(
"expects a MaskableOpInterface within the mask region");
7660 return emitOpError(
"expects number of results to match maskable operation "
7661 "number of results");
7663 if (!llvm::equal(maskableOp->
getResults(), terminator.getOperands()))
7664 return emitOpError(
"expects all the results from the MaskableOpInterface "
7665 "to match all the values returned by the terminator");
7667 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
7669 "expects result type to match maskable operation result type");
7672 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7673 return emitOpError(
"multiple vector results not supported");
7676 Type expectedMaskType = maskableOp.getExpectedMaskType();
7677 if (getMask().
getType() != expectedMaskType)
7679 << expectedMaskType <<
" mask for the maskable operation";
7682 Value passthru = getPassthru();
7684 if (!maskableOp.supportsPassthru())
7686 "doesn't expect a passthru argument for this maskable operation");
7689 return emitOpError(
"expects result when passthru argument is provided");
7692 return emitOpError(
"expects passthru type to match result type");
7712static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7713 SmallVectorImpl<OpFoldResult> &results) {
7714 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7717 Block *block = maskOp.getMaskBlock();
7718 auto terminator = cast<vector::YieldOp>(block->
front());
7719 if (terminator.getNumOperands() == 0)
7723 llvm::append_range(results, terminator.getOperands());
7727LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7728 SmallVectorImpl<OpFoldResult> &results) {
7729 if (succeeded(foldEmptyMaskOp(*
this, adaptor, results)))
7739 Operation *maskableOp = getMaskableOp();
7745 llvm::append_range(results, maskableOp->
getResults());
7761class CanonializeEmptyMaskOp :
public OpRewritePattern<MaskOp> {
7764 LogicalResult matchAndRewrite(MaskOp maskOp,
7765 PatternRewriter &rewriter)
const override {
7766 if (!maskOp.isEmpty())
7769 if (!maskOp.hasPassthru())
7776 VectorType maskType = maskOp.getMask().getType();
7777 for (Type resultType : maskOp.getResultTypes()) {
7778 auto vecResultType = dyn_cast<VectorType>(resultType);
7779 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
7783 Block *block = maskOp.getMaskBlock();
7784 auto terminator = cast<vector::YieldOp>(block->
front());
7785 assert(terminator.getNumOperands() == 1 &&
7786 "expected one result when passthru is provided");
7789 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7790 terminator.getOperand(0), maskOp.getPassthru());
7796void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7797 MLIRContext *context) {
7798 results.
add<CanonializeEmptyMaskOp>(context);
7804Operation *MaskOp::getMaskableOp() {
7805 Block *block = getMaskBlock();
7809 return &block->
front();
7813bool MaskOp::hasPassthru() {
return getPassthru() != Value(); }
7819LogicalResult ScanOp::verify() {
7820 VectorType srcType = getSourceType();
7821 VectorType initialType = getInitialValueType();
7823 int64_t srcRank = srcType.getRank();
7824 int64_t reductionDim = getReductionDim();
7825 if (reductionDim >= srcRank)
7827 << reductionDim <<
" has to be less than " << srcRank;
7830 int64_t initialValueRank = initialType.getRank();
7831 if (initialValueRank != srcRank - 1)
7833 << initialValueRank <<
" has to be equal to " << srcRank - 1;
7836 ArrayRef<int64_t> srcShape = srcType.getShape();
7837 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7838 SmallVector<int64_t> expectedShape;
7839 for (
int i = 0; i < srcRank; i++) {
7840 if (i != reductionDim)
7841 expectedShape.push_back(srcShape[i]);
7843 if (!llvm::equal(initialValueShapes, expectedShape)) {
7844 return emitOpError(
"incompatible input/initial value shapes");
7848 Type eltType = getDestType().getElementType();
7851 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
7858 RewritePatternSet &patterns, PatternBenefit benefit) {
7860 .
add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7861 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7862 StridedSliceConstantMaskFolder, TransposeFolder>(
7867 CombiningKind kind, Value v1, Value acc,
7868 arith::FastMathFlagsAttr fastmath,
7875 case CombiningKind::ADD:
7877 result =
b.createOrFold<arith::AddIOp>(loc, v1, acc);
7878 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7879 result =
b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7881 llvm_unreachable(
"invalid value types for ADD reduction");
7883 case CombiningKind::AND:
7885 result =
b.createOrFold<arith::AndIOp>(loc, v1, acc);
7887 case CombiningKind::MAXNUMF:
7888 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7889 "expected float values");
7890 result =
b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7892 case CombiningKind::MAXIMUMF:
7893 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7894 "expected float values");
7895 result =
b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7897 case CombiningKind::MINNUMF:
7898 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7899 "expected float values");
7900 result =
b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7902 case CombiningKind::MINIMUMF:
7903 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7904 "expected float values");
7905 result =
b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7907 case CombiningKind::MAXSI:
7909 result =
b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7911 case CombiningKind::MINSI:
7913 result =
b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7915 case CombiningKind::MAXUI:
7917 result =
b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7919 case CombiningKind::MINUI:
7921 result =
b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7923 case CombiningKind::MUL:
7925 result =
b.createOrFold<arith::MulIOp>(loc, v1, acc);
7926 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7927 result =
b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7929 llvm_unreachable(
"invalid value types for MUL reduction");
7931 case CombiningKind::OR:
7933 result =
b.createOrFold<arith::OrIOp>(loc, v1, acc);
7935 case CombiningKind::XOR:
7937 result =
b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7941 assert(
result &&
"unknown CombiningKind");
7949void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7951 auto resultType = cast<VectorType>(
getType());
7952 if (resultType.isScalable()) {
7956 APInt zero(bitwidth, 0);
7957 APInt high(bitwidth, resultType.getDimSize(0) - 1);
7958 ConstantIntRanges
result = {zero, high, zero, high};
7959 setResultRanges(getResult(),
result);
7989struct StepCompareFolder :
public OpRewritePattern<StepOp> {
7992 LogicalResult matchAndRewrite(StepOp stepOp,
7993 PatternRewriter &rewriter)
const override {
7994 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
7996 for (OpOperand &use : stepOp.getResult().getUses()) {
7997 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8002 const unsigned stepOperandNumber = use.getOperandNumber();
8003 if (stepOperandNumber != 0)
8007 unsigned constOperandNumber = 1;
8008 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8009 std::optional<int64_t> maybeConstValue =
8011 if (!maybeConstValue.has_value())
8014 int64_t constValue = maybeConstValue.value();
8015 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8017 auto maybeSplat = [&]() -> std::optional<bool> {
8019 if ((pred == arith::CmpIPredicate::ult ||
8020 pred == arith::CmpIPredicate::uge) &&
8021 stepSize <= constValue)
8022 return pred == arith::CmpIPredicate::ult;
8025 if ((pred == arith::CmpIPredicate::ule ||
8026 pred == arith::CmpIPredicate::ugt) &&
8027 stepSize - 1 <= constValue) {
8028 return pred == arith::CmpIPredicate::ule;
8032 if ((pred == arith::CmpIPredicate::eq ||
8033 pred == arith::CmpIPredicate::ne) &&
8034 stepSize <= constValue)
8035 return pred == arith::CmpIPredicate::ne;
8037 return std::nullopt;
8040 if (!maybeSplat.has_value())
8045 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8050 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8062void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8063 MLIRContext *context) {
8064 results.
add<StepCompareFolder>(context);
8074 Operation *maskableOp) {
8075 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");
8087 Operation *maskableOp, Value mask,
8092 return MaskOp::create(builder, maskableOp->
getLoc(),
8095 return MaskOp::create(builder, maskableOp->
getLoc(),
8108 Value newValue, Value passthru) {
8112 return arith::SelectOp::create(builder, newValue.
getLoc(), newValue.
getType(),
8113 mask, newValue, passthru);
8120#define GET_ATTRDEF_CLASSES
8121#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8123#define GET_OP_CLASSES
8124#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meaning it could be the same value as the initial constant splat.
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)