#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
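// The helper below classifies a mask value as all-true, all-false, or unknown
// by inspecting its defining op: a dense i1 constant, a vector.constant_mask,
// or a vector.create_mask whose operands are constants.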
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Count up for set bits, down for cleared bits; a mix means unknown.
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // If every mask index covers its dimension, all bits are set; if an index
    // is zero or less, no bits are set in that dimension.
    ArrayRef<int64_t> masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
      if (maskIdx > 0)
        allFalse = false;
    }
    if (allTrue)
      return MaskFormat::AllTrue;
    if (allFalse)
      return MaskFormat::AllFalse;
  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
    // A create_mask with a constant non-positive operand is all-false.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
  }
  return MaskFormat::Unknown;
}
/// Default callback to build a region with a 'vector.yield' terminator with
/// no arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  vector::YieldOp::create(builder, loc);
}
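// Whether a combining kind is supported on an element type: ADD and MUL accept
// any int/index/float element type, the integer min/max and bitwise kinds
// require an int/index type, and the floating-point min/max kinds require a
// float type.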
bool mlir::vector::isSupportedCombiningKind(CombiningKind combiningKind,
                                            Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
                                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(),
      getEffectiveVectorRankForXferOp(shapedType, vectorType),
      shapedType.getContext());
}
/// Check if `write` is of a constant splat and the masked `read` is padded
/// with the same splat value -- meaning it could be the same value as the
/// initial constant splat.
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // The splat value could be the same if the read is masked (and padded with
  // the splat value) and the write is unmasked or uses the same mask. A masked
  // write with an unmasked read could read more elements than were written.
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for a constant splat as the source of the write.
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat())
    return false;
  // The padding of the read and the constant splat value must be the same.
  Attribute padAttr;
  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
    return false;
  return padAttr == splatAttr.getSplatValue<Attribute>();
}

static bool checkSameValueRAW(vector::TransferWriteOp defWrite,
                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}

static bool checkSameValueWAW(vector::TransferWriteOp write,
                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different, the accessed slices are disjoint.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to fully compose and simplify the affine expression as a
        // fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For the transfer dimensions, make sure the accessed intervals do not
      // overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to fully compose and simplify the affine expression as a
        // fast track.
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
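// Example: two transfers of vector<4xf32> from the same base at constant
// indices 0 and 8 are disjoint, since |0 - 8| >= 4 along the accessed
// dimension; for a leading (non-transfer) dimension the slices are disjoint
// only when the indices provably differ.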
/// Increments `position` by one within the bounds given by `shape` and
/// `offsets`, carrying into the next (more significant) dimension.
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();
    // Carry the overflow into the next loop iteration.
    posInDim = offsetInDim;
  }
  return failure();
}

/// Returns the integer numbers in `values`, which are expected to be constant
/// operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}

/// Returns the integer numbers in `foldResults`, which are expected to be
/// constant attributes.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  return ints;
}

/// Converts `foldResults` into Values; integer attributes become constant ops.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return arith::ConstantIndexOp::create(
                                 builder, loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();
                    return cast<Value>(foldResult);
                  });
  return values;
}
std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
  if (value.getDefiningOp<vector::VectorScaleOp>())
    return 1;
  auto mul = value.getDefiningOp<arith::MulIOp>();
  if (!mul)
    return {};
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  return {};
}

/// Converts an IntegerAttr to have the specified type, if needed. This handles
/// cases where integer constant attributes have a different type than the
/// target element type.
static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto intType = dyn_cast<IntegerType>(expectedType)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
  }

  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    auto intType = dyn_cast<IntegerType>(expectedType);
    if (!intType)
      return attr;
    // Bitcast the float payload into an integer of the expected width.
    APFloat floatVal = floatAttr.getValue();
    APInt intVal = floatVal.bitcastToAPInt();
    return IntegerAttr::get(expectedType, intVal);
  }
  return attr;
}
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
};

void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}

Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  // Constant poison values need special handling, e.g. via ub.poison.
  if (isa<ub::PoisonAttrInterface>(value))
    return value.getDialect().materializeConstant(builder, value, type, loc);

  return arith::ConstantOp::materialize(builder, value, type, loc);
}
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim: this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit but not reduced, it is not folded, keeping the output type unchanged.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      }
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // All dimensions are reduced and all reduced dimensions have size 1, so
      // a simple extraction suffices.
      if (mask)
        mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}

LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::OR, vector);
  case arith::AtomicRMWKind::minnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINNUMF, vector);
  case arith::AtomicRMWKind::maxnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXNUMF, vector);
  case arith::AtomicRMWKind::xori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::XOR, vector);
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
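// Usage sketch (assuming a builder `b`, a location `loc`, and a vector-typed
// value `val`):
//   Value red = getVectorReductionOp(arith::AtomicRMWKind::addf, b, loc, val);
// builds `vector.reduction <add>, %val` and returns its scalar result.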
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::map_to_vector(
          iteratorTypes, [&](IteratorType t) -> Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          })));
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(),
                                             ContractionOp::getDefaultKind()));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
}
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;

  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert an array of strings into an array of IteratorType enums. This is
  // needed because tests still use the old format where 'iterator_types' is
  // represented as an array of strings.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);

  std::array<VectorType, 2> maskTypes = {
void ContractionOp::print(OpAsmPrinter &p) {
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert the IteratorType enums into their string representation: tests
      // still use the old format where 'iterator_types' is an array of
      // strings.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          });

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");
    // Infer the expected result vector type: the lhs/rhs maps and vector types
    // fully determine it, assuming the affine maps are well-formed (verified
    // already).
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape =
        llvm::map_to_vector<4>(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        });
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated vector
  // operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  if (!getKindAttr()) {
    return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
                       "'vector.kind<add>')");
  }
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
}
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // dimension in the vector.contract operation from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), 1),
                         maskShapeScalableDims);
}
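// Example (a sketch): for the 4x8 * 8x16 matmul-style contraction above, the
// iteration space is (m, n, k) = (4, 16, 8), so the expected mask type is
// vector<4x16x8xi1>, with per-dimension scalability carried over from the
// operands.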
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}

/// Collect the (LHS dim, RHS dim) pairs whose iterator type matches
/// `targetIteratorType`.
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for the dimension expression.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for the dimension expression.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhs shape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
/// Canonicalize `contract(a, b, 0) + c` into `contract(a, b, c)` by folding
/// an addition into the accumulator of a zero-accumulator contraction.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source) {
  auto vectorTy = cast<VectorType>(source.getType());
  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}

bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting a 1-element vector instead of a scalar.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}
LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the result type");

  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::map_to_vector<4>(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
}
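// Usage sketch: given an op carrying an I64 ArrayAttr of offsets,
// `extractVector<int64_t>(op.getOffsets())` materializes the attribute into a
// SmallVector<int64_t> so the positions can be manipulated arithmetically.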
/// Fold the result of chains of ExtractOp in place by simply concatenating
/// the positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
    return failure();
  // Canonicalization for dynamic positions is not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getSource());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}
/// Walk back a chain of InsertOp/TransposeOp feeding an ExtractOp and fold
/// the extract when the positions can be reconciled.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over the chain and fold when possible.
  Value fold();

private:
  /// Return true if `a` is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if all non-negative entries of `a` and `b` agree (entries
  /// where either side is negative, i.e. a sentinel, are skipped).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible when the remaining position entries are exactly
  /// the sentinels.
  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Update the state for the next iteration of the chain walk.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<InsertOp>();
    nextTransposeOp = v.getDefiningOp<TransposeOp>();
  }

  // Case 1: the value is produced by a transpose: compose the permutation.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches the extract position exactly.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  // Case 3: the insert position is a prefix of the extract position.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place and return the folded result.
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the positions not yet matched.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
  llvm::append_range(extractPosition, sentinels);
}

// Case 1: If we hit a transpose, just compose the map and iterate.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // Canonicalization for dynamic positions is not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
  return success();
}

// Case 2: If the insert position matches the extract position, sub-extract
// from the inserted value.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  res = nextInsertOp.getValueToStore();
  // If an internal transposition is present, canFold() will be false.
  return success(canFold());
}

// Case 3: If the inserted position is a prefix of the extract position,
// extract a portion of the source of the insert.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  res = nextInsertOp.getValueToStore();
  return success();
}

/// If we can fold, update the extractOp in place and return its result.
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // Canonicalization for dynamic positions is not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  // If there is nothing to fold, or an internal transposition blocks folding,
  // bail.
  bool nothingToFold = (source == extractOp.getSource());
  if (nothingToFold || !canFold())
    return Value();

  // Otherwise, fold by updating the op in place.
  OpBuilder b(extractOp.getContext());
  extractOp.setStaticPosition(
      ArrayRef(extractPosition).take_front(extractedRank));
  extractOp.getSourceMutable().assign(source);
  return extractOp.getResult();
}
Value ExtractFromInsertTransposeChainState::fold() {
  // Canonicalization for dynamic positions is not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getSource();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1: if we hit a transpose, compose the map and iterate.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: the inserted position is a prefix of the extract position.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: the extract position intersects the insert position on
    // non-sentinel values: bail.
    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
    if (intersectsWhereNonNegative(insertedPos, extractPosition))
      return Value();

    // Case 5: the insert is disjoint from the extract; keep walking through
    // the insert's destination.
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If nothing was folded along the chain, try folding in place.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
/// Returns true if the operation has a 0-D vector type operand or result.
static bool hasZeroDimVectors(Operation *op) {
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;
  };

  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}

/// BroadcastOps, and ShapeCastOps that only prepend 1s, are broadcast-like.
static bool isBroadcastLike(Operation *op) {
  if (isa<BroadcastOp>(op))
    return true;

  auto shapeCast = dyn_cast<ShapeCastOp>(op);
  if (!shapeCast)
    return false;

  VectorType srcType = shapeCast.getSourceVectorType();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  uint64_t srcRank = srcType.getRank();
  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
}

/// Fold extract(broadcast-like(x)) to either extract(x) or x itself.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.getSource().getDefiningOp();
  if (!defOp || !isBroadcastLike(defOp))
    return Value();

  Value input = defOp->getOperand(0);
  if (extractOp.getType() == input.getType())
    return input;

  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;

  // Cannot extract a larger rank than the broadcast source.
  if (extractRank > inputRank)
    return Value();

  assert(inputType && "input must be a vector type because of previous checks");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  if (extractType &&
      extractType.getShape() != inputShape.take_back(extractRank))
    return Value();

  // Replace extract(broadcast(X)) with extract(X) at an adjusted position.
  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;
  OpBuilder b(extractOp.getContext());
  IntegerAttr zero = b.getIndexAttr(0);
  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
  SmallVector<OpFoldResult> newPositions(deltaOverall);
  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
  }
  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
}
/// Fold an ExtractOp whose source comes from a ShuffleOp by redirecting the
/// extraction to the shuffled input selected by the mask.
static Value foldExtractFromShuffle(ExtractOp extractOp) {
  // Dynamic positions are not folded: the resulting code would be more
  // complex than the input.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // 0-D and multi-dimensional vectors are not supported yet.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Select the input vector the mask index points at.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }
  return extractOp.getResult();
}
/// Fold an ExtractOp from a ShapeCastOp by relinearizing the position.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  // Dynamic positions are not folded.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from the lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the source.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract position and linearize it.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then compute the strides of the shape_cast's vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
}

/// Fold an ExtractOp from an ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  // Dynamic positions are not folded.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  // Bail if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions that are fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the strided slice.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
}
/// Fold an ExtractOp fed from a chain of InsertStridedSliceOps.
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
  // Dynamic positions are not folded.
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the inserted chunk.
    if (!disjoint) {
      // If any inner dimension is only partially inserted, we have a partial
      // overlap and cannot fold.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getSourceMutable().assign(insertOp.getValueToStore());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // The extracted chunk is disjoint from the inserted chunk: keep looking
    // through the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}

/// Fold a scalar extract whose source comes from a FromElementsOp.
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  // Look for extract(from_elements).
  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute the flattened/linearized index and fold to the operand.
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
}
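// Example: with %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5
//   : vector<2x3xf32>, extracting %v[1, 2] folds to %e5, since the row-major
// linearized index is 1 * 3 + 2 = 5.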
/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
/// fold them into the static position.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
  ArrayRef<int64_t> vectorShape;
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();
  else
    vectorShape = op.getDestVectorType().getShape();

  // If there are no dynamic operands, there is nothing to fold.
  if (!dynamicPosition.size())
    return {};

  // `index` iterates over `dynamicPosition`; `opChange` tracks whether `op`
  // needs updating in place.
  unsigned index = 0;
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (ShapedType::isStatic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
      // Do not fold if the value is out of range.
      if (value >= 0 && value < vectorShape[i]) {
        staticPosition[i] = attr.getInt();
        opChange = true;
        continue;
      }
    }
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    return op.getResult();
  }
  return {};
}

/// Fold an insert or extract operation into a poison value when a poison
/// index is found at any dimension of the static position.
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
                                                ArrayRef<int64_t> staticPos,
                                                int64_t poisonVal) {
  if (!is_contained(staticPos, poisonVal))
    return {};
  return ub::PoisonAttr::get(context);
}

/// Fold a vector extract from a poison source into a poison result.
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
    return srcAttr;
  return {};
}

/// Fold a vector extract extracting from a DenseElementsAttr.
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
                                                   Attribute srcAttr) {
  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
  if (!denseAttr)
    return {};

  if (denseAttr.isSplat()) {
    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    return newAttr;
  }

  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())
    return {};

  if (extractOp.hasDynamicPosition())
    return {};

  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
  copy(extractOp.getStaticPosition(), completePositions.begin());
  int64_t startPos =
      linearize(completePositions, computeStrides(vecTy.getShape()));
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  Attribute newAttr;
  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
    SmallVector<Attribute> elementValues(
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
  } else {
    newAttr = *denseValuesBegin;
  }

  return newAttr;
}
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // Fold extract with no indices when source and result types match.
  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
    return getSource();

  // Fold `arith.constant` indices into the `vector.extract` operation. Do not
  // stop here, as this fold may enable subsequent folds that require constant
  // indices.
  SmallVector<Value> operands = {getSource()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);

  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;

  return inplaceFolded;
}
// Pattern to rewrite an ExtractOp(broadcast-like) -> BroadcastOp.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    Operation *defOp = extractOp.getSource().getDefiningOp();
    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
    if (!defOp || !isBroadcastLike(defOp) || !outType)
      return failure();
    Value source = defOp->getOperand(0);
    if (isBroadcastableTo(source.getType(), outType) !=
        BroadcastableToResult::Success)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
    return success();
  }
};

// Pattern to rewrite an ExtractOp(CreateMaskOp) -> CreateMaskOp.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto createMaskOp =
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dimension are unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the range of the create_mask, the
        // extracted mask is all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dimension is not all-true, and with a dynamic index we cannot
        // tell whether the extraction lands in the true or false region.
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
  }
};
// Fold extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();

  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();

  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
  return success();
}

/// Canonicalize the extraction of a subvector from a vector defined by
/// vector.from_elements into a smaller vector.from_elements.
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of the first extracted element and linearize it.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);
  int flatIndex = 0;
  int stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  // Replace the op with a smaller from_elements op.
  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}
/// Replace `vector.extract` with `vector.shape_cast` when the element counts
/// match and all positions are zero.
struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = extractOp.getSourceVectorType();
    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
    if (!outType)
      return failure();

    if (sourceType.getNumElements() != outType.getNumElements())
      return rewriter.notifyMatchFailure(
          extractOp, "extract to vector with fewer elements");

    // Negative values in `position` indicate poison and cannot be folded.
    if (llvm::any_of(extractOp.getMixedPosition(),
                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
      return rewriter.notifyMatchFailure(extractOp,
                                         "leaving for extract poison folder");

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
                                                     extractOp.getSource());
    return success();
  }
};

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
          context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// Returns true if all the `operands` are defined by `defOp`; returns false
/// otherwise.
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
  if (operands.empty())
    return false;

  return llvm::all_of(operands, [&](Value operand) {
    Operation *currentDef = operand.getDefiningOp();
    return currentDef == defOp;
  });
}
/// Folds vector.to_elements(vector.from_elements(e0, ..., eN)) into
/// (e0, ..., eN).
static LogicalResult
foldToElementsFromElements(ToElementsOp toElementsOp,
                           SmallVectorImpl<OpFoldResult> &results) {
  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();

  llvm::append_range(results, fromElementsOp.getElements());
  return success();
}

/// Folds vector.to_elements(vector.broadcast(%scalar)) by replicating the
/// scalar once per result element.
static LogicalResult
foldToElementsBroadcast(ToElementsOp toElementsOp,
                        SmallVectorImpl<OpFoldResult> &results) {
  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
  if (!bcastOp)
    return failure();
  // Vector sources are handled by the ToElementsOfBroadcast pattern below.
  if (isa<VectorType>(bcastOp.getSource().getType()))
    return failure();

  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
  Value scalar = bcastOp.getSource();
  results.assign(resultVecType.getNumElements(), OpFoldResult(scalar));
  return success();
}

LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldToElementsFromElements(*this, results)))
    return success();
  return foldToElementsBroadcast(*this, results);
}

LogicalResult
ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                               ToElementsOp::Adaptor adaptor,
                               SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vecType = cast<VectorType>(adaptor.getSource().getType());
  Type elType = vecType.getElementType();
  inferredReturnTypes.append(vecType.getNumElements(), elType);
  return success();
}
/// Rewrite vector.to_elements(vector.broadcast(%v : vector)) to a single
/// to_elements on the source, replicating each source element according to
/// the broadcast.
struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
    if (!bcastOp)
      return failure();
    // Scalar broadcasts are handled by the folder.
    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
    if (!srcType)
      return failure();

    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
    ArrayRef<int64_t> dstShape = dstType.getShape();
    ArrayRef<int64_t> srcShape = srcType.getShape();
    int64_t dstRank = dstShape.size();
    int64_t srcRank = srcShape.size();

    auto srcElems = vector::ToElementsOp::create(
        rewriter, toElementsOp.getLoc(), bcastOp.getSource());

    int64_t dstCount = llvm::product_of(dstShape);
    SmallVector<Value> replacements;
    replacements.reserve(dstCount);

    // Map each destination element back to its source element: broadcast
    // (size-1) source dims use index 0; the remaining trailing dims copy the
    // destination index.
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);
    SmallVector<int64_t> srcIdx(srcRank);
    for (int64_t lin = 0; lin < dstCount; ++lin) {
      SmallVector<int64_t> dstIdx = delinearize(lin, dstStrides);
      for (int64_t k = 0; k < srcRank; ++k)
        srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
      int64_t srcLin = linearize(srcIdx, srcStrides);
      replacements.push_back(srcElems.getResult(srcLin));
    }

    rewriter.replaceOp(toElementsOp, replacements);
    return success();
  }
};

void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ToElementsOfBroadcast>(context);
}
2594void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2595 MLIRContext *context) {
2596 results.
add<ToElementsOfBroadcast>(context);
/// Folds vector.from_elements(vector.to_elements(%v)) into %v when the
/// elements are forwarded in order.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
  OperandRange fromElemsOperands = fromElementsOp.getElements();
  if (fromElemsOperands.empty())
    return {};

  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
  if (!toElementsOp)
    return {};

  Value toElementsInput = toElementsOp.getSource();
  if (fromElementsOp.getType() == toElementsInput.getType() &&
      llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
    return toElementsInput;
  }
  return {};
}

/// Fold vector.from_elements to a constant when all operands are constants.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
                                               ArrayRef<Attribute> elements) {
  if (llvm::any_of(elements, [](Attribute attr) {
        return !attr || isa<ub::PoisonAttrInterface>(attr);
      }))
    return {};

  // The fold helper cannot convert attributes to all element types.
  auto destVecType = fromElementsOp.getDest().getType();
  auto destEltType = destVecType.getElementType();
  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
    return {};

  // Constant attributes might have a different type than the return type;
  // convert them before building the dense elements attribute.
  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
    return convertIntegerAttr(attr, destEltType);
  });
  return DenseElementsAttr::get(destVecType, convertedElements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (auto res = foldFromElementsToElements(*this))
    return res;
  return foldFromElementsToConstant(*this, adaptor.getElements());
}

/// Rewrite from_elements with all-equal operands as a broadcast of the common
/// scalar.
static LogicalResult
rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
                               PatternRewriter &rewriter) {
  if (!llvm::all_equal(fromElementsOp.getElements()))
    return failure();
  rewriter.replaceOpWithNewOp<BroadcastOp>(
      fromElementsOp, fromElementsOp.getType(),
      fromElementsOp.getElements().front());
  return success();
}
  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {
    // Handled by the all-equal-elements rewrite above.
    if (fromElements.getType().getNumElements() == 1)
      return failure();

    // The common source that all elements are extracted from, if one exists.
    TypedValue<VectorType> source;
    // The position of the combined extract operation, if one is created.
    ArrayRef<int64_t> combinedPosition;
    // The expected position of the current element, if elements are extracted
    // in ascending order.
    SmallVector<int64_t> expectedPosition;

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {
      // Check that the element comes from a vector.extract operation.
      auto extractOp = element.getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");

      // Condition (i): all elements share the source of the first element.
      if (insertIndex == 0) {
        source = extractOp.getSource();
      } else if (extractOp.getSource() != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }

      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      // Condition (ii): the first element is extracted at a position whose
      // suffix covers all elements of `fromElements`.
      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        int64_t numSuffixElms = 1;
        int64_t index = rank;
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);
          --index;
        }
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements do not form a suffix of source");
        }
        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);
      }

      // Condition (iii): elements are extracted in ascending order.
      else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      }
      increment(expectedPosition, source.getType().getShape());
    }

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        fromElements, fromElements.getType(), extracted);
    return success();
  }
/// Increments n-D `indices` by 1 starting from the innermost dimension,
/// wrapping within `shape`.
static void increment(MutableArrayRef<int64_t> indices,
                      ArrayRef<int64_t> shape) {
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    indices[dim] += 1;
    if (indices[dim] < shape[dim])
      break;
    indices[dim] = 0;
  }
}

void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
                           ArrayRef<int64_t> dstShape) {
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  llvm::SetVector<int64_t> res;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");
      res.insert(dstDim);
    }
    ++dstDim;
  }
  return res;
}

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  // Scalar broadcast has no unit-dim broadcast.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (!srcVectorType)
    return {};
  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
                                      getResultVectorType().getShape());
}
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in dstShape are broadcasted.
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness check: the source shape must equal dstShape with the
  // broadcast dimensions removed.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 1. If the source is a scalar, broadcast directly.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 2. Since vector.broadcast only creates leading dims, build a
  // `broadcastShape` where all broadcast dims come first, then transpose them
  // into their final positions.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // Broadcast dims become leading dims of broadcastShape; `permutation`
      // records where each destination dim comes from.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // Otherwise the dim comes from the source vector, in order.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // No "dim-1" broadcasts should remain.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // Step 3. If needed, transpose the broadcasted dims into place.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  return res;
}
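// Worked example (a sketch): broadcasting vector<4x1x2xf32> to dstShape
// [4, 3, 1, 2] with broadcastedDims = {1} first builds broadcastShape
// [3, 4, 1, 2] (broadcast dim leading), broadcasts to vector<3x4x1x2xf32>,
// then transposes with permutation [1, 0, 2, 3] to get vector<4x3x1x2xf32>.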
BroadcastableToResult mlir::vector::isBroadcastableTo(
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcast a scalar to a vector of the same element type.
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // The source must have an exact match or a singleton value for all trailing
  // dimensions; leading dimensions are simply duplicated.
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have any mismatching dims (in the sense of vector.broadcast semantics)
    // been encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine; everything else mixing fixed-width and scalable
        // dims is rejected.
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}
LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
/// For broadcast(shape_cast(x)), if the shape_cast only prepends or drops
/// leading 1s, the broadcast can apply to x directly.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
  if (!srcShapeCast)
    return failure();

  VectorType srcType = srcShapeCast.getSourceVectorType();
  VectorType destType = broadcastOp.getResultVectorType();
  if (vector::isBroadcastableTo(srcType, destType) !=
      BroadcastableToResult::Success)
    return failure();

  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> shapecastShape =
      srcShapeCast.getResultVectorType().getShape();

  // The shape_cast may only add or remove leading 1s.
  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
  if (!llvm::equal(srcShape.take_back(numTrailingDims),
                   shapecastShape.take_back(numTrailingDims)))
    return failure();

  assert(all_of(srcShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         all_of(shapecastShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         "ill-formed shape_cast");

  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
  return success();
}

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (succeeded(foldBroadcastOfShapeCast(*this)))
    return getResult();

  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());
  return {};
}

// Fold broadcast1(broadcast2(x)) into broadcast1(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
                                             broadcastOp.getResultVectorType(),
                                             srcBroadcast.getSource());
    return success();
  }
};
/// Replace `vector.broadcast` with `vector.shape_cast` when the number of
/// elements is unchanged.
struct BroadcastToShapeCast final
    : public OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
                                PatternRewriter &rewriter) const override {
    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
    if (!sourceType) {
      return rewriter.notifyMatchFailure(
          broadcast, "source is a scalar, shape_cast doesn't support scalar");
    }

    VectorType outType = broadcast.getType();
    if (sourceType.getNumElements() != outType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          broadcast, "broadcast to a greater number of elements");
    }

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
                                                     broadcast.getSource());
    return success();
  }
};

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // The leading dimension matches the mask length; all trailing dimensions
  // match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}

template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // The 0-D shuffle is a special case packing two scalars into a 1-D vector.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
  auto mask = getMask();
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return getV1();
  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return getV2();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // One of the operands may be poison: use the first element of the other
  // operand as a placeholder for poison-indexed elements.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    if (!v2DenseAttr)
      return {};
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    if (!v1DenseAttr)
      return {};
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
3310struct Canonicalize0DShuffleOp :
public OpRewritePattern<ShuffleOp> {
3313 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3314 PatternRewriter &rewriter)
const override {
3315 VectorType v1VectorType = shuffleOp.getV1VectorType();
3316 ArrayRef<int64_t> mask = shuffleOp.getMask();
3317 if (v1VectorType.getRank() > 0)
3319 if (mask.size() != 1)
3321 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3339static Value getScalarSplatSource(Value value) {
3345 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3352 if (isa<VectorType>(
broadcast.getSourceType()))
3360class ShuffleSplat final :
public OpRewritePattern<ShuffleOp> {
3364 LogicalResult matchAndRewrite(ShuffleOp op,
3365 PatternRewriter &rewriter)
const override {
3366 Value splat = getScalarSplatSource(op.getV1());
3367 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3377class ShuffleInterleave :
public OpRewritePattern<ShuffleOp> {
3381 LogicalResult matchAndRewrite(ShuffleOp op,
3382 PatternRewriter &rewriter)
const override {
3383 VectorType resultType = op.getResultVectorType();
3384 if (resultType.isScalable())
3386 op,
"ShuffleOp can't represent a scalable interleave");
3388 if (resultType.getRank() != 1)
3390 op,
"ShuffleOp can't represent an n-D interleave");
3392 VectorType sourceType = op.getV1VectorType();
3393 if (sourceType != op.getV2VectorType() ||
3394 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3396 op,
"ShuffleOp types don't match an interleave");
3399 ArrayRef<int64_t> shuffleMask = op.getMask();
3400 int64_t resultVectorSize = resultType.getNumElements();
3401 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3402 int64_t maskValueA = shuffleMask[i * 2];
3403 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3404 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3406 "ShuffleOp mask not interleaving");
3416void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3417 MLIRContext *context) {
3418 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3426void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3428 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3431void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3432 Value source, Value dest) {
3433 auto vectorTy = cast<VectorType>(dest.
getType());
3434 build(builder,
result, source, dest,
3435 SmallVector<int64_t>(vectorTy.getRank(), 0));
3438void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3439 Value source, Value dest, int64_t position) {
3440 build(builder,
result, source, dest, ArrayRef<int64_t>{position});
3443void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3444 Value source, Value dest, OpFoldResult position) {
3445 build(builder,
result, source, dest, ArrayRef<OpFoldResult>{position});
3448void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3449 Value source, Value dest,
3450 ArrayRef<int64_t> position) {
3451 SmallVector<OpFoldResult> posVals;
3452 posVals.reserve(position.size());
3453 llvm::transform(position, std::back_inserter(posVals),
3455 build(builder,
result, source, dest, posVals);
3458void vector::InsertOp::build(OpBuilder &builder, OperationState &
result,
3459 Value source, Value dest,
3460 ArrayRef<OpFoldResult> position) {
3461 SmallVector<int64_t> staticPos;
3462 SmallVector<Value> dynamicPos;
3464 build(builder,
result, source, dest, dynamicPos,
3468LogicalResult InsertOp::verify() {
3469 if (
auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3470 if (srcTy.getRank() == 0)
3472 "expected a scalar instead of a 0-d vector as the source operand");
3474 SmallVector<OpFoldResult> position = getMixedPosition();
3475 auto destVectorType = getDestVectorType();
3476 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
3478 "expected position attribute of rank no greater than dest vector rank");
3479 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3480 if (srcVectorType &&
3481 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3482 static_cast<unsigned>(destVectorType.getRank())))
3483 return emitOpError(
"expected position attribute rank + source rank to "
3484 "match dest vector rank");
3485 if (!srcVectorType &&
3486 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
3488 "expected position attribute rank to match the dest vector rank");
3489 for (
auto [idx, pos] : llvm::enumerate(position)) {
3490 if (
auto attr = dyn_cast<Attribute>(pos)) {
3491 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3493 destVectorType.getDimSize(idx))) {
3494 return emitOpError(
"expected position attribute #")
3496 <<
" to be a non-negative integer smaller than the "
3498 "dest vector dimension";
3511 assert(positions.size() <= completePositions.size() &&
3512 "positions size must be less than or equal to destTy rank");
3513 copy(positions, completePositions.begin());
3521class InsertToBroadcast final :
public OpRewritePattern<InsertOp> {
3525 LogicalResult matchAndRewrite(InsertOp insertOp,
3526 PatternRewriter &rewriter)
const override {
3528 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3529 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3530 srcVecType.getNumElements())
3533 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3539class InsertSplatToSplat final :
public OpRewritePattern<InsertOp> {
3543 LogicalResult matchAndRewrite(InsertOp op,
3544 PatternRewriter &rewriter)
const override {
3546 Value splat = getScalarSplatSource(op.getValueToStore());
3547 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3575class InsertChainFullyInitialized final :
public OpRewritePattern<InsertOp> {
3578 LogicalResult matchAndRewrite(InsertOp op,
3579 PatternRewriter &rewriter)
const override {
3581 VectorType destTy = op.getDestVectorType();
3582 if (destTy.isScalable())
3585 for (Operation *user : op.getResult().getUsers())
3586 if (
auto insertOp = dyn_cast<InsertOp>(user))
3587 if (insertOp.getDest() == op.getResult())
3590 InsertOp currentOp = op;
3591 SmallVector<InsertOp> chainInsertOps;
3594 if (currentOp.hasDynamicPosition())
3597 chainInsertOps.push_back(currentOp);
3598 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3601 if (currentOp && !currentOp->hasOneUse())
3605 int64_t vectorSize = destTy.getNumElements();
3606 int64_t initializedCount = 0;
3607 SmallVector<bool> initializedDestIdxs(vectorSize,
false);
3608 SmallVector<int64_t> pendingInsertPos;
3609 SmallVector<int64_t> pendingInsertSize;
3610 SmallVector<Value> pendingInsertValues;
3612 for (
auto insertOp : chainInsertOps) {
3614 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3618 int64_t insertBeginPosition =
3623 int64_t insertSize = 1;
3624 if (
auto srcVectorType =
3625 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3626 insertSize = srcVectorType.getNumElements();
3628 assert(insertBeginPosition + insertSize <= vectorSize &&
3629 "insert would overflow the vector");
3631 for (
auto index : llvm::seq<int64_t>(insertBeginPosition,
3632 insertBeginPosition + insertSize)) {
3633 if (initializedDestIdxs[index])
3635 initializedDestIdxs[index] =
true;
3641 pendingInsertPos.push_back(insertBeginPosition);
3642 pendingInsertSize.push_back(insertSize);
3643 pendingInsertValues.push_back(insertOp.getValueToStore());
3645 if (initializedCount == vectorSize)
3650 if (initializedCount != vectorSize)
3653 SmallVector<Value> elements(vectorSize);
3654 for (
auto [insertBeginPosition, insertSize, valueToStore] :
3655 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3656 pendingInsertValues))) {
3657 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3659 if (!srcVectorType) {
3660 elements[insertBeginPosition] = valueToStore;
3664 SmallVector<Type> elementToInsertTypes(insertSize,
3665 srcVectorType.getElementType());
3667 auto elementsToInsert = vector::ToElementsOp::create(
3668 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3669 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3670 elements[insertBeginPosition + linearIdx] =
3671 elementsToInsert.getResult(linearIdx);
3685 int64_t maxVectorSizeFoldThreshold) {
3686 if (insertOp.hasDynamicPosition())
3689 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3697 VectorType destTy = insertOp.getDestVectorType();
3698 if (destTy.isScalable())
3702 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3703 !insertOp->hasOneUse())
3710 Type destEltType = destTy.getElementType();
3714 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3715 for (
auto value : denseSource.getValues<
Attribute>())
3721 auto allValues = llvm::to_vector(denseDst.getValues<
Attribute>());
3722 copy(insertedValues, allValues.begin() + insertBeginPosition);
3731 auto destInsert = insertOp.getDest().
getDefiningOp<InsertOp>();
3735 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3738 insertOp.
setOperand(1, destInsert.getDest());
3739 return insertOp.getResult();
3742void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3743 MLIRContext *context) {
3744 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3745 InsertChainFullyInitialized>(context);
3748OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3751 constexpr int64_t vectorSizeFoldThreshold = 256;
3755 if (getNumIndices() == 0 && getValueToStoreType() ==
getType())
3756 return getValueToStore();
3760 SmallVector<Value> operands = {getValueToStore(), getDest()};
3766 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3769 *
this, adaptor.getValueToStore(), adaptor.getDest(),
3770 vectorSizeFoldThreshold)) {
3774 return inplaceFolded;
3781void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
3782 Value source, Value dest,
3783 ArrayRef<int64_t> offsets,
3784 ArrayRef<int64_t> strides) {
3785 result.addOperands({source, dest});
3789 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(
result.name),
3791 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(
result.name),
3796template <
typename OpType>
3800 StringRef attrName) {
3801 if (arrayAttr.size() >
shape.size())
3802 return op.emitOpError(
"expected ")
3803 << attrName <<
" attribute of rank no greater than vector rank";
3810template <
typename OpType>
3814 bool halfOpen =
true) {
3815 for (
auto attr : arrayAttr) {
3816 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3820 if (val < min || val >= upper)
3821 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
3822 <<
min <<
", " << upper <<
")";
3830template <
typename OpType>
3835 for (
auto [
index, attrDimPair] :
3836 llvm::enumerate(llvm::zip_first(arrayAttr,
shape))) {
3837 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3841 if (val < min || val >=
max)
3842 return op.emitOpError(
"expected ")
3843 << attrName <<
" dimension " <<
index <<
" to be confined to ["
3844 <<
min <<
", " <<
max <<
")";
3854template <
typename OpType>
3859 assert(arrayAttr1.size() <=
shape.size());
3860 assert(arrayAttr2.size() <=
shape.size());
3861 for (
auto [
index, it] :
3862 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2,
shape))) {
3863 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3864 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3868 if (val1 + val2 < 0 || val1 + val2 >=
max)
3869 return op.emitOpError(
"expected sum(")
3870 << attrName1 <<
", " << attrName2 <<
") dimension " <<
index
3871 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3879 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3881 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3884LogicalResult InsertStridedSliceOp::verify() {
3885 auto sourceVectorType = getSourceVectorType();
3886 auto destVectorType = getDestVectorType();
3887 auto offsets = getOffsetsAttr();
3888 auto strides = getStridesAttr();
3889 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3891 "expected offsets of same size as destination vector rank");
3892 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3893 return emitOpError(
"expected strides of same size as source vector rank");
3894 if (sourceVectorType.getRank() > destVectorType.getRank())
3896 "expected source rank to be no greater than destination rank");
3898 auto sourceShape = sourceVectorType.getShape();
3899 auto destShape = destVectorType.getShape();
3900 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3901 destShape.size() - sourceShape.size(), 0);
3902 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3903 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3904 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3913 offName,
"source vector shape",
3917 unsigned rankDiff = destShape.size() - sourceShape.size();
3918 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3919 if (sourceVectorType.getScalableDims()[idx] !=
3920 destVectorType.getScalableDims()[idx + rankDiff]) {
3921 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3924 if (sourceVectorType.getScalableDims()[idx]) {
3925 auto sourceSize = sourceShape[idx];
3926 auto destSize = destShape[idx + rankDiff];
3927 if (sourceSize != destSize) {
3930 << (
" to match the corresponding base size from the input "
3932 << sourceSize << (
" vs ") << destSize << (
")");
3942class FoldInsertStridedSliceSplat final
3943 :
public OpRewritePattern<InsertStridedSliceOp> {
3947 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3948 PatternRewriter &rewriter)
const override {
3950 auto dst = insertStridedSliceOp.getDest();
3951 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3952 if (!splat || getScalarSplatSource(dst) != splat)
3955 rewriter.
replaceOp(insertStridedSliceOp, dst);
3962class FoldInsertStridedSliceOfExtract final
3963 :
public OpRewritePattern<InsertStridedSliceOp> {
3967 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3968 PatternRewriter &rewriter)
const override {
3969 auto extractStridedSliceOp =
3970 insertStridedSliceOp.getValueToStore()
3971 .getDefiningOp<vector::ExtractStridedSliceOp>();
3973 if (!extractStridedSliceOp)
3976 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3980 if (extractStridedSliceOp.getStrides() !=
3981 insertStridedSliceOp.getStrides() ||
3982 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3985 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3992class InsertStridedSliceConstantFolder final
3993 :
public OpRewritePattern<InsertStridedSliceOp> {
3999 static constexpr int64_t vectorSizeFoldThreshold = 256;
4001 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4002 PatternRewriter &rewriter)
const override {
4006 Attribute vectorDestCst;
4010 VectorType destTy = destVector.getType();
4011 if (destTy.isScalable())
4015 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4016 !destVector.hasOneUse())
4020 Attribute sourceCst;
4025 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
4029 if (op.hasNonUnitStrides())
4032 VectorType sliceVecTy = sourceValue.getType();
4033 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4034 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4035 SmallVector<int64_t, 4> offsets =
getI64SubArray(op.getOffsets());
4036 SmallVector<int64_t, 4> destStrides =
computeStrides(destTy.getShape());
4044 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4045 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4046 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4047 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4048 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4049 MutableArrayRef<int64_t> currSlicePosition(
4050 currDestPosition.begin() + rankDifference, currDestPosition.end());
4051 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4054 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
4055 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
4056 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4057 "Invalid slice element");
4058 newValues[linearizedPosition] = *sliceValuesIt;
4071void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4072 RewritePatternSet &results, MLIRContext *context) {
4073 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4074 InsertStridedSliceConstantFolder>(context);
4077OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4078 if (getSourceVectorType() == getDestVectorType())
4079 return getValueToStore();
4088void OuterProductOp::build(OpBuilder &builder, OperationState &
result,
4089 Value
lhs, Value
rhs, Value acc) {
4094void OuterProductOp::print(OpAsmPrinter &p) {
4095 p <<
" " << getLhs() <<
", " << getRhs();
4097 p <<
", " << getAcc();
4100 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
4103ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &
result) {
4104 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4111 if (operandsInfo.size() < 2)
4113 "expected at least 2 operands");
4114 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4115 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4118 "expected vector type for operand #1");
4122 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4123 vRHS.getScalableDims()[0]};
4124 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4125 vLHS.getElementType(), scalableDimsRes);
4128 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4129 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4133 if (!
result.attributes.get(OuterProductOp::getKindAttrName(
result.name))) {
4134 result.attributes.append(
4135 OuterProductOp::getKindAttrName(
result.name),
4136 CombiningKindAttr::get(
result.getContext(),
4137 OuterProductOp::getDefaultKind()));
4143 (operandsInfo.size() > 2 &&
4148LogicalResult OuterProductOp::verify() {
4149 Type tRHS = getOperandTypeRHS();
4150 VectorType vLHS = getOperandVectorTypeLHS(),
4151 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4152 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4154 if (vLHS.getRank() != 1)
4155 return emitOpError(
"expected 1-d vector for operand #1");
4159 if (vRHS.getRank() != 1)
4160 return emitOpError(
"expected 1-d vector for operand #2");
4161 if (vRES.getRank() != 2)
4163 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4164 return emitOpError(
"expected #1 operand dim to match result dim #1");
4165 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4166 return emitOpError(
"expected #2 operand dim to match result dim #2");
4167 if (vLHS.isScalable() && !vRHS.isScalable()) {
4171 "expected either both or only #2 operand dim to be scalable");
4175 if (vRES.getRank() != 1)
4177 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4178 return emitOpError(
"expected #1 operand dim to match result dim #1");
4181 if (vACC && vACC != vRES)
4182 return emitOpError(
"expected operand #3 of same type as result type");
4184 if (!getKindAttr()) {
4185 return emitOpError(
"expected 'kind' attribute of type CombiningKind (e.g. "
4186 "'vector.kind<add>')");
4191 return emitOpError(
"unsupported outerproduct type");
4200Type OuterProductOp::getExpectedMaskType() {
4201 auto vecType = this->getResultVectorType();
4202 return VectorType::get(vecType.getShape(),
4203 IntegerType::get(vecType.getContext(), 1),
4204 vecType.getScalableDims());
4218 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4220 shape.reserve(vectorType.getRank());
4222 for (
unsigned e = offsets.size(); idx < e; ++idx)
4223 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4224 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4225 shape.push_back(vectorType.getShape()[idx]);
4227 return VectorType::get(
shape, vectorType.getElementType(),
4228 vectorType.getScalableDims());
4231void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &
result,
4232 Value source, ArrayRef<int64_t> offsets,
4233 ArrayRef<int64_t> sizes,
4234 ArrayRef<int64_t> strides) {
4235 result.addOperands(source);
4241 offsetsAttr, sizesAttr, stridesAttr));
4242 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(
result.name),
4244 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(
result.name),
4246 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(
result.name),
4250LogicalResult ExtractStridedSliceOp::verify() {
4251 auto type = getSourceVectorType();
4252 auto offsets = getOffsetsAttr();
4253 auto sizes = getSizesAttr();
4254 auto strides = getStridesAttr();
4255 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4257 "expected offsets, sizes and strides attributes of same size");
4259 auto shape = type.getShape();
4260 auto offName = getOffsetsAttrName();
4261 auto sizesName = getSizesAttrName();
4262 auto stridesName = getStridesAttrName();
4278 shape, offName, sizesName,
4283 offsets, sizes, strides);
4284 if (getResult().
getType() != resultType)
4285 return emitOpError(
"expected result type to be ") << resultType;
4287 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
4288 if (type.getScalableDims()[idx]) {
4289 auto inputDim = type.getShape()[idx];
4290 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4291 if (inputDim != inputSize)
4294 << (
" to match the corresponding base size from the input "
4296 << inputSize << (
" vs ") << inputDim << (
")");
4309 auto getElement = [](
ArrayAttr array,
int idx) {
4310 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4312 ArrayAttr extractOffsets = op.getOffsets();
4315 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4317 if (op.getSourceVectorType().getRank() !=
4318 insertOp.getSourceVectorType().getRank())
4320 ArrayAttr insertOffsets = insertOp.getOffsets();
4321 ArrayAttr insertStrides = insertOp.getStrides();
4324 if (extractOffsets.size() > insertOffsets.size())
4326 bool patialoverlap =
false;
4327 bool disjoint =
false;
4329 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4330 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
4332 int64_t start = getElement(insertOffsets, dim);
4333 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4334 int64_t offset = getElement(extractOffsets, dim);
4335 int64_t size = getElement(extractSizes, dim);
4337 if (start <= offset && offset < end) {
4340 if (offset + size > end)
4341 patialoverlap =
true;
4342 offsetDiffs.push_back(offset - start);
4349 if (!disjoint && !patialoverlap) {
4350 op.setOperand(insertOp.getValueToStore());
4353 op.setOffsetsAttr(
b.getI64ArrayAttr(offsetDiffs));
4359 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4374 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4379 if (op.hasNonUnitStrides())
4382 VectorType sourceVecTy = op.getSourceVectorType();
4386 VectorType sliceVecTy = op.getType();
4388 int64_t rank = sliceVecTy.getRank();
4400 const auto denseValuesBegin = dense.value_begin<
Attribute>();
4402 sliceValues.reserve(sliceVecTy.getNumElements());
4406 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4408 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4409 }
while (succeeded(
incSlicePosition(currSlicePosition, sliceShape, offsets)));
4411 assert(
static_cast<int64_t>(sliceValues.size()) ==
4412 sliceVecTy.getNumElements() &&
4413 "Invalid number of slice elements");
4417OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4418 if (getSourceVectorType() == getResult().
getType())
4425 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4432void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4454class StridedSliceFolder final
4455 :
public OpRewritePattern<ExtractStridedSliceOp> {
4457 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4459 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4460 PatternRewriter &rewriter)
const override {
4461 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4465 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4468 SmallVector<int64_t> firstOffsets =
getI64SubArray(firstOp.getOffsets());
4469 SmallVector<int64_t> firstSizes =
getI64SubArray(firstOp.getSizes());
4470 SmallVector<int64_t> secondOffsets =
getI64SubArray(secondOp.getOffsets());
4471 SmallVector<int64_t> secondSizes =
getI64SubArray(secondOp.getSizes());
4473 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4474 SmallVector<int64_t> combinedOffsets(newRank, 0);
4475 SmallVector<int64_t> combinedSizes(newRank);
4476 ArrayRef<int64_t> firstSourceShape =
4477 firstOp.getSourceVectorType().getShape();
4478 for (
unsigned i = 0; i < newRank; ++i) {
4479 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4480 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4481 combinedOffsets[i] = off1 + off2;
4483 if (i < secondSizes.size()) {
4484 combinedSizes[i] = secondSizes[i];
4485 }
else if (i < firstSizes.size()) {
4486 combinedSizes[i] = firstSizes[i];
4488 combinedSizes[i] = firstSourceShape[i];
4492 SmallVector<int64_t> combinedStrides(newRank, 1);
4494 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4512class StridedSliceCreateMaskFolder final
4513 :
public OpRewritePattern<ExtractStridedSliceOp> {
4517 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4518 PatternRewriter &rewriter)
const override {
4519 Location loc = extractStridedSliceOp.getLoc();
4523 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4527 if (extractStridedSliceOp.hasNonUnitStrides())
4530 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4532 SmallVector<int64_t> sliceOffsets;
4535 SmallVector<int64_t> sliceSizes;
4539 SmallVector<Value> sliceMaskDimSizes;
4540 sliceMaskDimSizes.reserve(maskDimSizes.size());
4544 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4545 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4549 IntegerAttr offsetAttr =
4551 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4552 Value sliceMaskDimSize =
4553 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4554 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4559 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4563 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4571class StridedSliceConstantMaskFolder final
4572 :
public OpRewritePattern<ExtractStridedSliceOp> {
4576 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4577 PatternRewriter &rewriter)
const override {
4580 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4581 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4582 if (!constantMaskOp)
4585 if (extractStridedSliceOp.hasNonUnitStrides())
4588 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4590 SmallVector<int64_t> sliceOffsets;
4593 SmallVector<int64_t> sliceSizes;
4597 SmallVector<int64_t> sliceMaskDimSizes;
4598 sliceMaskDimSizes.reserve(maskDimSizes.size());
4599 for (
auto [maskDimSize, sliceOffset, sliceSize] :
4600 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4601 int64_t sliceMaskDimSize = std::max(
4602 static_cast<int64_t
>(0),
4603 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4604 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4607 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4608 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4609 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4612 if (llvm::is_contained(sliceMaskDimSizes, 0))
4613 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4618 extractStridedSliceOp, extractStridedSliceOp.getResult().
getType(),
4626class StridedSliceBroadcast final
4627 :
public OpRewritePattern<ExtractStridedSliceOp> {
4631 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4632 PatternRewriter &rewriter)
const override {
4638 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4639 auto dstVecType = llvm::cast<VectorType>(op.getType());
4640 unsigned dstRank = dstVecType.getRank();
4641 unsigned rankDiff = dstRank - srcRank;
4645 bool needsSlice =
false;
4646 for (
unsigned i = 0; i < srcRank; i++) {
4647 if (srcVecType.getDimSize(i) != 1 &&
4648 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4655 SmallVector<int64_t> offsets =
4657 SmallVector<int64_t> sizes =
4659 for (
unsigned i = 0; i < srcRank; i++) {
4660 if (srcVecType.getDimSize(i) == 1) {
4668 source = ExtractStridedSliceOp::create(
4669 rewriter, op->getLoc(), source, offsets, sizes,
4678class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
4682 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4683 PatternRewriter &rewriter)
const override {
4685 Value splat = getScalarSplatSource(op.getSource());
4709class ContiguousExtractStridedSliceToExtract final
4710 :
public OpRewritePattern<ExtractStridedSliceOp> {
4714 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4715 PatternRewriter &rewriter)
const override {
4716 if (op.hasNonUnitStrides())
4718 Value source = op.getOperand();
4719 auto sourceType = cast<VectorType>(source.
getType());
4720 if (sourceType.isScalable() || sourceType.getRank() == 0)
4729 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4730 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4737 if (numOffsets == 0)
4742 if (numOffsets == sourceType.getRank() &&
4743 static_cast<int>(sizes.size()) == sourceType.getRank())
4747 for (
int i = 0; i < numOffsets; ++i) {
4755 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
4756 sizes[numOffsets] == 1) {
4761 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4762 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4771void ExtractStridedSliceOp::getCanonicalizationPatterns(
4772 RewritePatternSet &results, MLIRContext *context) {
4775 results.
add<StridedSliceFolder, StridedSliceCreateMaskFolder,
4776 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4777 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4786void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4787 VectorType vectorType, Value source,
4789 AffineMapAttr permutationMapAttr,
4792 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4794 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4795 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
4796 *padding, Value(), inBoundsAttr);
4800void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4801 VectorType vectorType, Value source,
4803 AffineMap permutationMap,
4804 std::optional<ArrayRef<bool>> inBounds) {
4805 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4806 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4809 SmallVector<bool>(vectorType.getRank(),
false));
4810 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4812 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4813 build(builder,
result, vectorType, source,
indices, *padding,
4814 permutationMapAttr, inBoundsAttr);
4818void TransferReadOp::build(OpBuilder &builder, OperationState &
result,
4819 VectorType vectorType, Value source,
4821 std::optional<ArrayRef<bool>> inBounds) {
4823 llvm::cast<ShapedType>(source.
getType()), vectorType);
4824 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4825 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4828 SmallVector<bool>(vectorType.getRank(),
false));
4829 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4831 padding = ub::PoisonOp::create(builder,
result.location, elemType);
4832 build(builder,
result, vectorType, source,
indices, permutationMapAttr,
4834 Value(), inBoundsAttr);
4837template <
typename EmitFun>
4841 for (
auto expr : permutationMap.
getResults()) {
4842 auto dim = dyn_cast<AffineDimExpr>(expr);
4843 auto zero = dyn_cast<AffineConstantExpr>(expr);
4845 if (zero.getValue() != 0) {
4847 "requires a projected permutation_map (at most one dim or the zero "
4848 "constant can appear in each result)");
4853 return emitOpError(
"requires a projected permutation_map (at most one "
4854 "dim or the zero constant can appear in each result)");
4856 if (seen[dim.getPosition()]) {
4858 "requires a permutation_map that is a permutation (found one dim "
4859 "used more than once)");
4861 seen[dim.getPosition()] =
true;
4868 VectorType vectorType, VectorType maskType,
4869 VectorType inferredMaskType,
AffineMap permutationMap,
4871 if (op->hasAttr(
"masked")) {
4872 return op->emitOpError(
"masked attribute has been removed. "
4873 "Use in_bounds instead.");
4876 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4877 return op->emitOpError(
4878 "requires source to be a memref or ranked tensor type");
4880 auto elementType = shapedType.getElementType();
4882 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4884 unsigned sourceVecSize =
4886 vectorElementType.getShape().back();
4887 unsigned resultVecSize =
4889 vectorType.getShape().back();
4890 if (resultVecSize % sourceVecSize != 0)
4891 return op->emitOpError(
4892 "requires the bitwidth of the minor 1-D vector to be an integral "
4893 "multiple of the bitwidth of the minor 1-D vector of the source");
4895 unsigned sourceVecEltRank = vectorElementType.getRank();
4896 unsigned resultVecRank = vectorType.getRank();
4897 if (sourceVecEltRank > resultVecRank)
4898 return op->emitOpError(
4899 "requires source vector element and vector result ranks to match.");
4900 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4903 return op->emitOpError(
"requires a permutation_map with result dims of "
4904 "the same rank as the vector type");
4907 return op->emitOpError(
"does not support masks with vector element type");
4910 unsigned minorSize =
4911 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4912 unsigned resultVecSize =
4915 return op->emitOpError(
4916 "requires the bitwidth of the minor 1-D vector to be an integral "
4917 "multiple of the bitwidth of the source element type");
4921 return op->emitOpError(
"requires a permutation_map with result dims of "
4922 "the same rank as the vector type");
4926 return op->emitOpError(
"requires permutation_map without symbols");
4928 if (permutationMap.
getNumInputs() != shapedType.getRank())
4929 return op->emitOpError(
"requires a permutation_map with input dims of the "
4930 "same rank as the source type");
4932 if (maskType && maskType != inferredMaskType)
4933 return op->emitOpError(
"inferred mask type (")
4934 << inferredMaskType <<
") and mask operand type (" << maskType
4938 return op->emitOpError(
"expects the in_bounds attr of same rank "
4939 "as permutation_map results: ")
4940 << AffineMapAttr::get(permutationMap)
4941 <<
" vs inBounds of size: " << inBounds.size();
4948 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4949 if (op.getPermutationMap().isMinorIdentity())
4950 elidedAttrs.push_back(op.getPermutationMapAttrName());
4952 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
4953 elidedAttrs.push_back(op.getInBoundsAttrName());
4957void TransferReadOp::print(OpAsmPrinter &p) {
4960 p <<
", " << getMask();
4967 auto i1Type = IntegerType::get(permMap.
getContext(), 1);
4969 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
4974 if (maskShape.empty())
4975 maskShape.push_back(1);
4980 return VectorType::get(maskShape, i1Type, scalableDims);
4997 if (hasMask.succeeded()) {
5004 if (types.size() != 2)
5005 return parser.
emitError(typesLoc,
"requires two types");
5007 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5008 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5009 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5010 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5012 return parser.
emitError(typesLoc,
"requires vector type");
5013 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(
result.name);
5017 if (shapedType.getRank() <
5020 "expected a custom permutation_map when "
5021 "rank(source) != rank(destination)");
5023 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5025 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5027 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(
result.name);
5028 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5029 if (!inBoundsAttr) {
5030 result.addAttribute(inBoundsAttrName,
5039 if (hasMask.succeeded()) {
5040 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5042 maskInfo.
location,
"does not support masks with vector element type");
5045 "expected the same rank for the vector and the "
5046 "results of the permutation map");
5054 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5056 {1, static_cast<int32_t>(indexInfo.size()), 1,
5057 static_cast<int32_t>(hasMask.succeeded())}));
5061LogicalResult TransferReadOp::verify() {
5063 ShapedType shapedType = getShapedType();
5065 VectorType maskType = getMaskType();
5066 auto paddingType = getPadding().getType();
5067 auto permutationMap = getPermutationMap();
5068 VectorType inferredMaskType =
5071 auto sourceElementType = shapedType.getElementType();
5073 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
5074 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5077 shapedType, vectorType, maskType,
5078 inferredMaskType, permutationMap, getInBounds())))
5081 if (
auto sourceVectorElementType =
5082 llvm::dyn_cast<VectorType>(sourceElementType)) {
5085 if (sourceVectorElementType != paddingType)
5087 "requires source element type and padding type to match.");
5091 if (!VectorType::isValidElementType(paddingType))
5092 return emitOpError(
"requires valid padding vector elemental type");
5095 if (paddingType != sourceElementType)
5097 "requires formal padding and source of the same elemental type");
5108Type TransferReadOp::getExpectedMaskType() {
5115VectorType TransferReadOp::getVectorType() {
5116 return cast<VectorType>(getVector().
getType());
5119template <
typename TransferOp>
5123 if (op.getShapedType().isDynamicDim(indicesIdx))
5127 if (!cstOp.has_value())
5130 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5131 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5133 return cstOp.value() + vectorSize <= sourceSize;
5136template <
typename TransferOp>
5140 if (op.getTransferRank() == 0)
5145 newInBounds.reserve(op.getTransferRank());
5150 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
5152 if (op.isDimInBounds(i)) {
5153 newInBounds.push_back(
true);
5158 bool inBounds =
false;
5159 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
5162 dimExpr.getPosition());
5163 nonBcastDims.push_back(i);
5166 newInBounds.push_back(inBounds);
5174 bool allNonBcastDimsInBounds = llvm::all_of(
5175 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
5176 if (allNonBcastDimsInBounds) {
5179 newInBounds[idx] =
true;
5187 op.setInBoundsAttr(
b.getBoolArrayAttr(newInBounds));
5191template <
typename TransferOp>
5193 auto mask = op.getMask();
5200 op.getMaskMutable().clear();
5214static Value foldRAW(TransferReadOp readOp) {
5215 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5217 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5220 return defWrite.getVector();
5222 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5223 cast<VectorTransferOpInterface>(readOp.getOperation())))
5225 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5230OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5231 if (Value vec = foldRAW(*
this))
5242 return OpFoldResult();
5245std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5249void TransferReadOp::getEffects(
5250 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5252 if (llvm::isa<MemRefType>(getShapedType()))
5253 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5254 SideEffects::DefaultResource::get());
5258 if (hasPureTensorSemantics())
5265static AffineMap inverseWithUnusedDims(AffineMap map) {
5267 "expected a projected permutation map");
5272 int64_t pos = cast<AffineDimExpr>(
result).getPosition();
5302struct TransferReadAfterWriteToBroadcast
5303 :
public OpRewritePattern<TransferReadOp> {
5306 LogicalResult matchAndRewrite(TransferReadOp readOp,
5307 PatternRewriter &rewriter)
const override {
5308 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5312 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5316 if (readOp.getMask() || defWrite.getMask())
5319 if (readOp.getIndices() != defWrite.getIndices())
5322 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5326 if (readOp.getTransferChunkAccessed() !=
5327 defWrite.getTransferChunkAccessed())
5334 AffineMap readMap = readOp.getPermutationMap();
5335 AffineMap writeMap = defWrite.getPermutationMap();
5336 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5337 AffineMap composedMap = readMap.
compose(invWriteMap);
5351 int64_t numBroadcastedDims = broadcastedDims.size();
5352 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5354 for (
auto [idx, expr] : llvm::enumerate(composedMap.
getResults())) {
5355 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
5356 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5357 invPerm[effectiveDim] = idx;
5362 VectorType readVecTy = readOp.getVectorType();
5364 auto broadcastedVecTy =
5366 readVecTy.getElementType(),
5369 Value vec = defWrite.getVector();
5370 Location loc = readOp.getLoc();
5371 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5378void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5379 MLIRContext *context) {
5380 results.
add<TransferReadAfterWriteToBroadcast>(context);
5383FailureOr<std::optional<SmallVector<Value>>>
5384TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5385 if (!hasPureBufferSemantics())
5396void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5398 AffineMapAttr permutationMapAttr,
5401 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
5402 build(builder,
result, resultType, vector, dest,
indices, permutationMapAttr,
5403 mask, inBoundsAttr);
5407void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5409 AffineMapAttr permutationMapAttr,
5411 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5412 Value(), inBoundsAttr);
5417void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5419 AffineMap permutationMap,
5420 std::optional<ArrayRef<bool>> inBounds) {
5421 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5423 (inBounds && !inBounds.value().empty())
5426 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
5427 build(builder,
result, vector, dest,
indices, permutationMapAttr,
5428 Value(), inBoundsAttr);
5433void TransferWriteOp::build(OpBuilder &builder, OperationState &
result,
5435 std::optional<ArrayRef<bool>> inBounds) {
5436 auto vectorType = llvm::cast<VectorType>(vector.
getType());
5438 llvm::cast<ShapedType>(dest.
getType()), vectorType);
5439 build(builder,
result, vector, dest,
indices, permutationMap, inBounds);
5442ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5443 OperationState &
result) {
5446 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5447 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5448 SmallVector<Type, 2> types;
5449 OpAsmParser::UnresolvedOperand maskInfo;
5455 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
5460 if (types.size() != 2)
5461 return parser.
emitError(typesLoc,
"requires two types");
5463 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5465 return parser.
emitError(typesLoc,
"requires vector type");
5466 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5467 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5468 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
5469 auto permMapAttrName =
5470 TransferWriteOp::getPermutationMapAttrName(
result.name);
5471 auto permMapAttr =
result.attributes.get(permMapAttrName);
5474 if (shapedType.getRank() <
5477 "expected a custom permutation_map when "
5478 "rank(source) != rank(destination)");
5480 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5482 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5484 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(
result.name);
5485 Attribute inBoundsAttr =
result.attributes.get(inBoundsAttrName);
5486 if (!inBoundsAttr) {
5487 result.addAttribute(inBoundsAttrName,
5495 if (hasMask.succeeded()) {
5496 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5498 maskInfo.
location,
"does not support masks with vector element type");
5501 "expected the same rank for the vector and the "
5502 "results of the permutation map");
5508 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5510 {1, 1, static_cast<int32_t>(indexInfo.size()),
5511 static_cast<int32_t>(hasMask.succeeded())}));
5512 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5516void TransferWriteOp::print(OpAsmPrinter &p) {
5519 p <<
", " << getMask();
5524LogicalResult TransferWriteOp::verify() {
5526 ShapedType shapedType = getShapedType();
5528 VectorType maskType = getMaskType();
5529 auto permutationMap = getPermutationMap();
5530 VectorType inferredMaskType =
5534 if (llvm::size(
getIndices()) != shapedType.getRank())
5535 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
5539 if (hasBroadcastDim())
5540 return emitOpError(
"should not have broadcast dimensions");
5543 shapedType, vectorType, maskType,
5544 inferredMaskType, permutationMap, getInBounds())))
5557Type TransferWriteOp::getExpectedMaskType() {
5564Value TransferWriteOp::getVector() {
return getOperand(0); }
5565VectorType TransferWriteOp::getVectorType() {
5566 return cast<VectorType>(getValueToStore().
getType());
5589static LogicalResult foldReadInitWrite(TransferWriteOp write,
5590 ArrayRef<Attribute>,
5591 SmallVectorImpl<OpFoldResult> &results) {
5593 if (write.getTransferRank() == 0)
5595 auto rankedTensorType =
5596 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5598 if (!rankedTensorType)
5601 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5605 if (read.getTransferRank() == 0)
5608 if (!read.getPermutationMap().isMinorIdentity() ||
5609 !write.getPermutationMap().isMinorIdentity())
5612 if (read.getTransferRank() != write.getTransferRank())
5615 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5618 if (read.getBase().getType() != rankedTensorType)
5621 if (read.getVectorType() != write.getVectorType())
5624 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5627 auto isNotConstantZero = [](Value v) {
5629 return !cstOp.has_value() || cstOp.value() != 0;
5631 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5632 llvm::any_of(write.getIndices(), isNotConstantZero))
5635 results.push_back(read.getBase());
5639static bool checkSameValueWAR(vector::TransferReadOp read,
5640 vector::TransferWriteOp write) {
5641 return read.getBase() == write.getBase() &&
5642 read.getIndices() == write.getIndices() &&
5643 read.getPermutationMap() == write.getPermutationMap() &&
5644 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5661static LogicalResult foldWAR(TransferWriteOp write,
5662 SmallVectorImpl<OpFoldResult> &results) {
5663 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5665 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5669 if (!checkSameValueWAR(read, write))
5671 results.push_back(read.getBase());
5675LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5676 SmallVectorImpl<OpFoldResult> &results) {
5677 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
5679 if (succeeded(foldWAR(*
this, results)))
5691std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5695void TransferWriteOp::getEffects(
5696 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5698 if (llvm::isa<MemRefType>(getShapedType()))
5699 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5700 SideEffects::DefaultResource::get());
5704 if (hasPureTensorSemantics())
5734class FoldWaw final :
public OpRewritePattern<TransferWriteOp> {
5737 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5738 PatternRewriter &rewriter)
const override {
5739 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5741 vector::TransferWriteOp writeToModify = writeOp;
5743 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5747 writeToModify.getBaseMutable().assign(defWrite.getBase());
5752 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5753 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5757 if (!defWrite->hasOneUse())
5759 writeToModify = defWrite;
5760 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5789struct SwapExtractSliceOfTransferWrite
5790 :
public OpRewritePattern<tensor::InsertSliceOp> {
5794 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5795 PatternRewriter &rewriter)
const override {
5796 if (!insertOp.hasUnitStride())
5799 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5800 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5802 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5803 if (!transferOp || !transferOp->hasOneUse())
5808 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5810 "use-def chain is rank-reducing");
5814 if (!extractOp.hasZeroOffset()) {
5816 "ExtractSliceOp has non-zero offset");
5820 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5821 return getConstantIntValue(value) == static_cast<int64_t>(0);
5824 "TranferWriteOp has non-zero offset");
5828 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5830 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
5833 for (
auto [insertSize, extractSize] :
5834 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5837 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
5842 assert(transferOp.getVectorType().hasStaticShape() &&
5843 "expected vector to have a static shape");
5844 ArrayRef<int64_t>
vectorShape = transferOp.getVectorType().getShape();
5846 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5847 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
5849 insertOp,
"TransferWriteOp may not write the full tensor.");
5854 SmallVector<bool> newInBounds(
vectorShape.size(),
false);
5855 auto newExtractOp = tensor::ExtractSliceOp::create(
5856 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5857 insertOp.getDest(), insertOp.getMixedOffsets(),
5858 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5859 auto newTransferWriteOp = TransferWriteOp::create(
5860 rewriter, transferOp.getLoc(), transferOp.getVector(),
5861 newExtractOp.getResult(), transferOp.getIndices(),
5862 transferOp.getPermutationMapAttr(),
5865 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5873void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5874 MLIRContext *context) {
5875 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5878FailureOr<std::optional<SmallVector<Value>>>
5879TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5880 if (!hasPureBufferSemantics())
5890static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5892 MemRefType memRefTy) {
5895 if (!vecTy.isScalable() &&
5896 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5899 if (!memRefTy.isLastDimUnitStride())
5900 return op->
emitOpError(
"most minor memref dim must have unit stride");
5904LogicalResult vector::LoadOp::verify() {
5908 if (
failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
5911 if (memRefTy.getRank() < resVecTy.getRank())
5913 "destination memref has lower rank than the result vector");
5916 Type memElemTy = memRefTy.getElementType();
5917 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5918 if (memVecTy != resVecTy)
5919 return emitOpError(
"base memref and result vector types should match");
5920 memElemTy = memVecTy.getElementType();
5923 if (resVecTy.getElementType() != memElemTy)
5924 return emitOpError(
"base and result element types should match");
5925 if (llvm::size(
getIndices()) != memRefTy.getRank())
5926 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5930OpFoldResult LoadOp::fold(FoldAdaptor) {
5933 return OpFoldResult();
5936std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5940FailureOr<std::optional<SmallVector<Value>>>
5941LoadOp::bubbleDownCasts(OpBuilder &builder) {
5950LogicalResult vector::StoreOp::verify() {
5954 if (
failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
5957 if (memRefTy.getRank() < valueVecTy.getRank())
5958 return emitOpError(
"source memref has lower rank than the vector to store");
5961 Type memElemTy = memRefTy.getElementType();
5962 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5963 if (memVecTy != valueVecTy)
5965 "base memref and valueToStore vector types should match");
5966 memElemTy = memVecTy.getElementType();
5969 if (valueVecTy.getElementType() != memElemTy)
5970 return emitOpError(
"base and valueToStore element type should match");
5971 if (llvm::size(
getIndices()) != memRefTy.getRank())
5972 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5976LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5977 SmallVectorImpl<OpFoldResult> &results) {
5981std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5985FailureOr<std::optional<SmallVector<Value>>>
5986StoreOp::bubbleDownCasts(OpBuilder &builder) {
5995LogicalResult MaskedLoadOp::verify() {
5996 VectorType maskVType = getMaskVectorType();
5997 VectorType passVType = getPassThruVectorType();
6004 if (llvm::size(
getIndices()) != memType.getRank())
6005 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6006 if (resVType.getShape() != maskVType.getShape())
6007 return emitOpError(
"expected result shape to match mask shape");
6008 if (resVType != passVType)
6009 return emitOpError(
"expected pass_thru of same type as result type");
6014class MaskedLoadFolder final :
public OpRewritePattern<MaskedLoadOp> {
6017 LogicalResult matchAndRewrite(MaskedLoadOp
load,
6018 PatternRewriter &rewriter)
const override {
6030 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
6035void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6036 MLIRContext *context) {
6037 results.
add<MaskedLoadFolder>(context);
6040OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6043 return OpFoldResult();
6046FailureOr<std::optional<SmallVector<Value>>>
6047MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6056LogicalResult MaskedStoreOp::verify() {
6057 VectorType maskVType = getMaskVectorType();
6064 if (llvm::size(
getIndices()) != memType.getRank())
6065 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
6066 if (valueVType.getShape() != maskVType.getShape())
6067 return emitOpError(
"expected valueToStore shape to match mask shape");
6072class MaskedStoreFolder final :
public OpRewritePattern<MaskedStoreOp> {
6075 LogicalResult matchAndRewrite(MaskedStoreOp store,
6076 PatternRewriter &rewriter)
const override {
6080 store, store.getValueToStore(), store.getBase(), store.getIndices());
6088 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), 1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
/// Returns success if `indexVec` is a 1-D, fixed-width vector holding the
/// zero-based contiguous sequence [0, 1, ..., n-1].
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  if (indexVec.getDefiningOp<StepOp>())
    return success();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
/// Fold gathers with an all-false mask to their pass_thru value. There is no
/// all-true equivalent, as gathers have no unmasked counterpart.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};
/// Fold gathers with zero-based contiguous indices (i.e. [0, 1, ..., n-1])
/// into a contiguous masked load.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getOffsets(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (valueVType.getElementType() != baseType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getOffsets()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
/// Fold scatters with an all-false mask: a memref base is left untouched (the
/// scatter is erased) and a tensor base is forwarded unchanged.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    ShapedType baseType = scatter.getBaseType();
    bool isMemRef = isa<MemRefType>(baseType);
    if (!isMemRef && !isa<RankedTensorType>(baseType))
      return failure();

    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      if (isMemRef)
        rewriter.eraseOp(scatter);
      else
        rewriter.replaceOp(scatter, scatter.getBase());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};
/// Fold scatters with zero-based contiguous indices (i.e. [0, 1, ..., n-1])
/// into a contiguous masked store.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");

    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
    return success();
  }
};
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}
/// Fold expanding loads with an all-true or all-false mask.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
/// Fold compressing stores with an all-true or all-false mask.
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // ...
}
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
LogicalResult ShapeCastOp::verify() {

  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that the element type is preserved.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitOpError("has different source and result element types");

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}
/// Return true if `transpose` does not permute any pair of non-unit dims,
/// i.e. the flattened source and result vectors are numerically identical
/// and the transpose is effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current)
        return false;
      current = p;
    }
  }
  return true;
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x).
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x), if the transpose preserves
  // the element order.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // shape_cast(broadcast(x)) -> x, if the types already match.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> constant.
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  // shape_cast(poison) -> poison.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());

  return {};
}
/// Helper function that computes a new vector type by removing the trailing
/// non-scalable unit dims, e.g. vector<4x1x1xi1> --> vector<4xi1>.
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
/// Folds shape_cast(create_mask) or shape_cast(constant_mask) when the
/// shape_cast only drops trailing unit dimensions whose corresponding mask
/// bound is the constant 1.
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check that every dropped mask bound is the constant 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check that every dropped mask dim size is 1.
      for (size_t i = numMaskOperands - 1;
           i >= numMaskOperands - numDimsToDrop; --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
///   i) Y = ShapeCast(X), or
///  ii) Y = Broadcast(X).
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;

    // If X has as many elements as Y, use Y = ShapeCast(X).
    if (srcVectorType &&
        srcVectorType.getNumElements() ==
            shapeCastOp.getResultVectorType().getNumElements()) {
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
          shapeCastOp, shapeCastOp.getResultVectorType(),
          broadcastOp.getSource());
      return success();
    }

    // Otherwise, if X is broadcastable to Y's type, use Y = Broadcast(X).
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
                           BroadcastableToResult::Success) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, dstVectorType, broadcastOp.getSource());
      return success();
    }
    return failure();
  }
};
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results
      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
          context);
}
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32: duplicate the 16-bit pattern.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(),
                                      {FloatAttr::get(dstElemType, floatBits)});
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        unsigned dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width: duplicate the lower-width
        // element pattern.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(
              getResultVectorType(), {IntegerAttr::get(dstElemType, intBits)});
        }
      }
    }
  }

  return {};
}
/// Returns the effective shape of a memref of vectors: the memref shape with
/// the shape of the vector element type appended.
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}
LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity transposes, and more generally any transpose that
  // preserves the shape without permuting elements.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return {};
}
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}
std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
/// Rewrites two back-to-back TransposeOp operations into a single one.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};
/// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value splat = getScalarSplatSource(transposeOp.getVector());
    if (!splat)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transposeOp, transposeOp.getResultVectorType(), splat);
    return success();
  }
};
/// Folds transpose(create_mask) and transpose(constant_mask) into a new
/// create_mask/constant_mask with permuted bounds.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Apply the transpose permutation to the mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};
/// Folds transpose(shape_cast) into a new shape_cast, when the transpose is
/// order preserving.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    // Merging the transpose into the shape_cast is always valid here, because
    // the transpose does not reorder any elements.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};
/// Replaces transpose(from_elements) with a from_elements whose scalar
/// operands are reordered according to the transpose permutation.
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElementsOp =
        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
    if (!fromElementsOp)
      return failure();

    VectorType srcTy = fromElementsOp.getDest().getType();
    VectorType dstTy = transposeOp.getType();

    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    int64_t rank = srcTy.getRank();

    // Invert the permutation: result dim i reads from source dim
    // permutation[i], so inversePerm[permutation[i]] == i.
    SmallVector<int64_t> inversePerm(rank, 0);
    for (int64_t i = 0; i < rank; ++i)
      inversePerm[permutation[i]] = i;

    ArrayRef<int64_t> srcShape = srcTy.getShape();
    ArrayRef<int64_t> dstShape = dstTy.getShape();
    SmallVector<int64_t> srcIdx(rank, 0);
    SmallVector<int64_t> dstIdx(rank, 0);
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);

    auto elementsOld = fromElementsOp.getElements();
    SmallVector<Value> elementsNew;
    int64_t dstNumElements = dstTy.getNumElements();
    elementsNew.reserve(dstNumElements);

    // Walk the destination in row-major order and pick the matching source
    // element for each position.
    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
      dstIdx = delinearize(linearIdx, dstStrides);
      for (int64_t j = 0; j < rank; ++j)
        srcIdx[j] = dstIdx[inversePerm[j]];
      int64_t srcLin = linearize(srcIdx, srcStrides);
      elementsNew.push_back(elementsOld[srcLin]);
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(transposeOp, dstTy,
                                                        elementsNew);
    return success();
  }
};
/// Folds transpose(broadcast(x)) to broadcast(x) when the permutation is
/// local to each group of broadcast unit dimensions, so that broadcasting
/// directly to the transposed output type is valid.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        // Fail if the permutation maps any index in [low, high) outside of
        // [low, high), i.e. the permutation is not local to the group.
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};
void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeFromElements,
              FoldTransposeBroadcast>(context);
}
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}
LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that the array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all are (the mask region
  // is a conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> bounds = getMaskDimSizes();
  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();

  auto createBoolSplat = [&](bool x) {
    return DenseElementsAttr::get(getVectorType(),
                                  BoolAttr::get(getContext(), x));
  };

  // Check the corner case of 0-D vectors first.
  if (vectorSizes.empty()) {
    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
    return createBoolSplat(bounds[0] == 1);
  }
  // Fold to a splat when the mask is all-true or all-false.
  if (bounds == vectorSizes)
    return createBoolSplat(true);
  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
    return createBoolSplat(false);
  return OpFoldResult();
}
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}
LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
/// Folds a CreateMaskOp with constant (or constant-times-vscale) operands
/// into a ConstantMaskOp.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank-zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine whether this CreateMaskOp can be folded to a ConstantMaskOp
    // and collect the constant mask dimension sizes.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // For non-scalable dims any constant is foldable; for scalable dims
        // only zero (all-false) is supported here.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // A vscale multiple must cover the whole dim (all-true) to fold.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp the values to the constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dims is zero, the mask is all-false.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();
  auto &builder = parser.getBuilder();

  // Parse the mask operand and the optional passthru operand.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse the op region and ensure it has a terminator.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}))
    return failure();
  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse the mask type and the optional result types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();
  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();
  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }
  return success();
}
void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print the single masked operation and skip the terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " } ";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.empty() || region.front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.front();
  if (isa<vector::YieldOp>(block.back()))
    return;

  // 3. Otherwise, create a default terminator if the number of masked
  // operations is not exactly one (this will trigger a verification failure).
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 4. Create a terminator that yields the results of the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask: nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Masked operation checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Folds an empty `vector.mask` (no masked op, no passthru) away, forwarding
/// the terminator operands as results.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  if (!maskOp.isEmpty() || maskOp.hasPassthru())
    return failure();

  Block *block = maskOp.getMaskBlock();
  auto terminator = cast<vector::YieldOp>(block->front());
  if (terminator.getNumOperands() == 0) {
    // `vector.mask` has no results: just remove the op.
    return success();
  }

  // `vector.mask` has results: propagate them.
  llvm::append_range(results, terminator.getOperands());
  return success();
}
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();

  MaskFormat maskFormat = getMaskFormat(getMask());
  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move the maskable operation outside of the `vector.mask` region.
  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
/// Canonicalize empty `vector.mask` operations that can't be handled in
/// `MaskOp::fold` as they require creating new operations.
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    if (!maskOp.isEmpty())
      return failure();

    if (!maskOp.hasPassthru())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
        terminator.getOperand(0), maskOp.getPassthru());
    return success();
  }
};
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}

// A non-empty MaskOp region holds exactly the masked operation followed by
// the terminator, so the maskable op is the block's first operation (or
// nullptr for an empty mask region).
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check that the reduction dimension is less than the source rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) == rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check the shapes of the initial value and the source.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify that the combining kind is supported for the element type.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRanges) {
  auto resultType = cast<VectorType>(getType());
  if (resultType.isScalable()) {
    // No assumptions can be made about the step values of scalable vectors.
    return;
  }
  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
  APInt zero(bitwidth, 0);
  APInt high(bitwidth, resultType.getDimSize(0) - 1);
  ConstantIntRanges result = {zero, high, zero, high};
  setResultRanges(getResult(), result);
}
/// Fold an `arith.cmpi` of a `vector.step` against a constant to a constant
/// splat when the comparison result is the same for every lane.
struct StepCompareFolder : public OpRewritePattern<StepOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    const int64_t stepSize = stepOp.getResult().getType().getNumElements();

    for (OpOperand &use : stepOp.getResult().getUses()) {
      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
      if (!cmpiOp)
        continue;

      // arith.cmpi canonicalization moves constants to the rhs, so only
      // handle the case where the step is operand 0.
      const unsigned stepOperandNumber = use.getOperandNumber();
      if (stepOperandNumber != 0)
        continue;

      // Check that the other operand is a constant.
      unsigned constOperandNumber = 1;
      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
      std::optional<int64_t> maybeConstValue =
          getConstantIntValue(otherOperand);
      if (!maybeConstValue.has_value())
        continue;

      int64_t constValue = maybeConstValue.value();
      arith::CmpIPredicate pred = cmpiOp.getPredicate();

      // The step values are [0, ..., stepSize - 1]; decide whether the
      // comparison is uniform across all of them.
      auto maybeSplat = [&]() -> std::optional<bool> {
        // step < C is all-true (and step >= C all-false) when stepSize <= C.
        if ((pred == arith::CmpIPredicate::ult ||
             pred == arith::CmpIPredicate::uge) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ult;

        // step <= C / step > C is uniform when stepSize - 1 <= C.
        if ((pred == arith::CmpIPredicate::ule ||
             pred == arith::CmpIPredicate::ugt) &&
            stepSize - 1 <= constValue) {
          return pred == arith::CmpIPredicate::ule;
        }

        // step == C is all-false (and step != C all-true) when stepSize <= C.
        if ((pred == arith::CmpIPredicate::eq ||
             pred == arith::CmpIPredicate::ne) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ne;

        return std::nullopt;
      }();

      if (!maybeSplat.has_value())
        continue;

      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
      if (!type)
        continue;

      Value splat = mlir::arith::ConstantOp::create(
          rewriter, cmpiOp.getLoc(),
          DenseElementsAttr::get(type,
                                 rewriter.getBoolAttr(maybeSplat.value())));
      rewriter.replaceOp(cmpiOp, splat);
      return success();
    }

    return failure();
  }
};
void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<StepCompareFolder>(context);
}
/// Creates the `vector.mask` op region for `maskableOp`: the op is moved into
/// the region and its results are yielded.
static void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}

/// Creates a `vector.mask` operation around `maskableOp`. Returns the mask
/// operation if a mask is provided, or the maskable operation itself
/// otherwise.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return MaskOp::create(builder, maskableOp->getLoc(),
                          maskableOp->getResultTypes(), mask, passthru,
                          maskableOp, createMaskOpRegion);
  return MaskOp::create(builder, maskableOp->getLoc(),
                        maskableOp->getResultTypes(), mask, maskableOp,
                        createMaskOpRegion);
}

/// Creates a vector select that picks `newValue` or `passthru` per lane based
/// on `mask`. Returns `newValue` unchanged when no mask is provided.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return arith::SelectOp::create(builder, newValue.getLoc(),
                                 newValue.getType(), mask, newValue, passthru);
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
OpListType & getOperations()
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
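A hedged sketch of the rewrite-pattern idiom these hooks support; MyOp and its getLhs/getRhs accessors are hypothetical:
LogicalResult matchAndRewrite(MyOp op, PatternRewriter &rewriter) const override {
  if (!op.getLhs() || !op.getRhs())
    return rewriter.notifyMatchFailure(op, "expected two operands");
  // Replace op's results with a freshly created arith.addi.
  rewriter.replaceOpWithNewOp<arith::AddIOp>(op, op.getLhs(), op.getRhs());
  return success();
}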
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
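A sketch of how the delta query is typically used to prove two index Values equal; index1 and index2 are hypothetical:
FailureOr<int64_t> delta =
    ValueBoundsConstraintSet::computeConstantDelta(index1, index2);
if (succeeded(delta) && *delta == 0) {
  // index1 and index2 provably address the same position
}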
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
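A one-line sketch, assuming builder and loc are in scope:
Value zero = arith::ConstantIndexOp::create(builder, loc, 0); // materializes arith.constant 0 : index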
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble down, in place, a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
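A sketch folding a new vector v into an accumulator acc (both hypothetical) with an additive combining kind:
Value sum = vector::makeArithReduction(b, loc, vector::CombiningKind::ADD, v, acc);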
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring t...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config, bool *changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
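A sketch of the common static-zero check, where ofr is a hypothetical OpFoldResult:
std::optional<int64_t> cst = getConstantIntValue(ofr);
bool isZeroOffset = cst.has_value() && *cst == 0; // true only if statically known to be 0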
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
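A worked example under the usual row-major convention; linearize (described further below) is assumed to be the inverse mapping:
SmallVector<int64_t> strides = computeStrides({4, 3});  // {3, 1}
SmallVector<int64_t> offsets = delinearize(7, strides); // {2, 1}, since 7 == 2*3 + 1
int64_t back = linearize(offsets, strides);             // 7 again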
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
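Both calls below yield the scalar f32 type; f32 is a hypothetical Type (e.g. builder.getF32Type()):
Type a = getElementTypeOrSelf(VectorType::get({4}, f32)); // element type of vector<4xf32>
Type b = getElementTypeOrSelf(f32);                       // already scalar, returned as-is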
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
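A sketch, with b, loc, and ofr hypothetical; no constant op is created when ofr already holds a Value:
Value idx = getValueOrCreateConstantIndexOp(b, loc, ofr);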
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
llvm::function_ref< Fn > function_ref
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
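A worked sketch of the permutation helpers on concrete values:
SmallVector<int64_t> vals = {10, 20, 30};
applyPermutationToVector(vals, {2, 0, 1});                     // vals == {30, 10, 20}
SmallVector<int64_t> inv = invertPermutationVector({2, 0, 1}); // {1, 2, 0}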
Return a fused vector::ContractionOp which represents a pattern such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
BitmaskEnumStorage(KeyTy val)