#include <type_traits>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"
// Helper to unpack an ArrayAttr of IntegerAttrs into a vector of integers.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
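// ShapeCastOpFolder: folds a pair of vector.shape_cast ops that cancel each
// other out (the second cast restores the type consumed by the first), so the
// original source value can be used directly.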
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
    auto resultVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!sourceVectorType || !resultVectorType)
      return failure();

    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        cast<VectorType>(sourceShapeCastOp.getSource().getType());
    auto operandResultVectorType = sourceShapeCastOp.getType();

    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
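// MultiReduceToContract: rewrites an additive vector.multi_reduction whose
// source is an elementwise arith.muli/mulf into a vector.contract, turning the
// reduced dimensions into "reduction" iterators and the rest into "parallel"
// iterators. Illustrative shape of the rewrite:
//   %0 = arith.mulf %a, %b : vector<8x32x16xf32>
//   %1 = vector.multi_reduction <add>, %0, %acc [1]
//          : vector<8x32x16xf32> to vector<8x16xf32>
// becomes a vector.contract over %a and %b with iterators
// ["parallel", "reduction", "parallel"].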
struct MultiReduceToContract
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
        0, exprs, reduceOp.getContext());
      return IteratorTypeAttr::get(rewriter.getContext(), t);
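// CombineContractABTranspose: folds vector.transpose ops feeding the LHS/RHS
// of a vector.contract into the contract's indexing maps by composing the
// transpose permutation with the corresponding map.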
struct CombineContractABTranspose final
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
          transposeOp.getPermutation(), contractOp.getContext());
      *operand = transposeOp.getVector();
        contractOp, lhs, rhs, contractOp.getAcc(),
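// CombineContractResultTranspose: for a single-use vector.contract whose
// result and accumulator are both transposed, folds the result transpose into
// the contract's result indexing map and uses the untransposed accumulator.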
struct CombineContractResultTranspose final
  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    auto combinedResMap = resTMap.compose(contractMap);
    maps.back() = combinedResMap;
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
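// CombineContractBroadcast: folds vector.broadcast ops feeding the LHS/RHS of
// a vector.contract into the contract's indexing maps, as long as only
// leading unit dimensions are broadcast and no non-unit reduction dimension
// ends up applying to just one side.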
struct CombineContractBroadcast
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
          srcType.getRank() == broadcast.getResultVectorType().getRank())
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
        originalDims.push_back(

      if (innerDimBroadcast)
        continue;

      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
          nonUnitDimReductionBroadcast = true;

      if (nonUnitDimReductionBroadcast)
        continue;

      map = broadcastMap.compose(map);

    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);

    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
        hasReductionIteratorApplyingOnBothSides = true;

    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();
        contractOp, lhs, rhs, contractOp.getAcc(),
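// ReorderCastOpsOnBroadcast: rewrites cast(broadcast(x)) into
// broadcast(cast(x)) so the element-wise cast runs on the smaller,
// pre-broadcast value.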
struct ReorderCastOpsOnBroadcast
  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
        bcastOp.getSource(), castResTy, op->getAttrs());
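// ReorderElementwiseOpsOnTranspose: when the vector operands of an
// elementwise op are produced by the same vector.transpose permutation,
// applies the elementwise op to the untransposed values (inserting inverse
// transposes for non-transposed operands) and transposes the result once.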
struct ReorderElementwiseOpsOnTranspose final
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
    if (transposeMaps.empty())
      return failure();
    if (!llvm::all_equal(transposeMaps))
      return failure();

    auto order = transposeMaps.front();
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
        srcValues.push_back(transposeOp.getVector());
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));

    auto vectorType = srcType.clone(

        transposeMaps.front());

  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
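// BubbleDownVectorBitCastForExtract: for a vector.extract of a 1-D
// vector.bitcast that splits elements into more, narrower ones (e.g.
// vector<4xi32> to vector<8xi16>), extracts the containing wide element,
// bitcasts only that element, and then extracts the requested sub-element.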
struct BubbleDownVectorBitCastForExtract
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    if (castSrcType.getNumElements() == 1)
      return failure();
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();
    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();
      assert(values[0].is<Attribute>() && "Unexpected non-constant index");
      return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();

    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,

    VectorType packedType =

        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
        index % expandRatio);
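// BubbleDownBitCastForStridedSliceExtract: moves a vector.bitcast past a
// unit-stride vector.extract_strided_slice by rescaling the innermost offset,
// size, and result shape by the element-count ratio, so the slice is taken
// directly from the pre-bitcast vector.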
struct BubbleDownBitCastForStridedSliceExtract
  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    if (castSrcLastDim > castDstLastDim)
      return failure();
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();
    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;

    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;

        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());
        extractOp, extractOp.getType(), newExtractOp);
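// BubbleUpBitCastForInsert: moves a vector.bitcast above a vector.insert by
// bitcasting the inserted source and the destination separately (rescaling
// their innermost dimension by the element-count ratio) and redoing the
// insert on the bitcast values.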
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }
    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
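// BubbleUpBitCastForStridedSliceInsert: the strided-slice variant of the
// pattern above; bitcasts the inserted source and destination separately and
// scales the innermost offsets by the shrink ratio, requiring unit strides.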
struct BubbleUpBitCastForStridedSliceInsert
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    if (castSrcType.getRank() == 0)
      return failure();
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    if (castSrcLastDim < castDstLastDim)
      return failure();
    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();
    unsigned rank = insertOp.getSourceVectorType().getRank();
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();
    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;

        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
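// BreakDownVectorBitCast: breaks a 1-D element-widening vector.bitcast into
// shrinkRatio smaller bitcasts over strided sub-slices, reassembling the
// result with insert_strided_slice; an optional controlFn restricts which
// bitcast ops get rewritten.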
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
                         std::function<bool(vector::BitCastOp)> controlFn,

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    if (castSrcType.getRank() != 1)
      return failure();
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    if (castSrcLastDim < castDstLastDim)
      return failure();
    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Type elemType = castDstType.getElementType();
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
    VectorType newCastDstType =
                        castDstType.getElementType());
    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          sliceShape, strides);
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(

  std::function<bool(BitCastOp)> controlFn;
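// ReorderElementwiseOpsOnBroadcast: when every operand of an elementwise op
// is a vector.broadcast or vector.splat from the same source type, performs
// the elementwise op on the narrow sources and broadcasts the result once.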
struct ReorderElementwiseOpsOnBroadcast final
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (isa<vector::FMAOp>(op)) {
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
      auto bcast = val.getDefiningOp<vector::BroadcastOp>();
        return (bcast.getOperand().getType() == lhsBcastOrSplatType);
      auto splat = val.getDefiningOp<vector::SplatOp>();
        return (splat.getOperand().getType() == lhsBcastOrSplatType);
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
        lhsBcastOrSplatType, op->getAttrs());
        op, vectorType, elementwiseOp->getResults());
                                   bool force32BitVectorIndices, int64_t dim,
  if (dim == 0 && force32BitVectorIndices) {
  } else if (dim == 0) {
  } else if (force32BitVectorIndices) {
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
template <typename ConcreteOp>
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();
    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    VectorType vtp = xferOp.getVectorType();
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value mask = rewriter.create<vector::CreateMaskOp>(
            vtp.getScalableDims()),
    if (xferOp.getMask()) {
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
      xferOp.getMaskMutable().assign(mask);

  const bool force32BitVectorIndices;
class VectorCreateMaskOpConversion
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),

  const bool force32BitVectorIndices;

static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp,
                                     bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];

  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
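// DropInnerMostUnitDimsTransferRead: drops trailing unit dimensions that are
// contiguous in memory from a vector.transfer_read by reading a lower-rank
// vector through a rank-reduced memref.subview, using the helper above to
// count how many dimensions can be dropped.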
class DropInnerMostUnitDimsTransferRead
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();
    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();
    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,

        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getPadding(),
        Value(), inBoundsAttr);
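// DropInnerMostUnitDimsTransferWrite: the write-side counterpart; shape-casts
// the vector to drop the trailing contiguous unit dims and writes it through
// a rank-reduced memref.subview.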
class DropInnerMostUnitDimsTransferWrite
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (writeOp.getTransferRank() == 0)
      return failure();
    if (writeOp.getMask())
      return failure();
    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();
    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();
    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,

        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
        writeOp, shapeCast, rankedReducedView,
        Value(), inBoundsAttr);
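// CanonicalizeContractMatmulToMMT: rewrites matmul-like vector.contract ops
// into the canonical "MMT" form {{m, k}, {n, k}, {m, n}} by transposing
// and/or swapping operands; transposes are hoisted above arith.extsi/extui so
// the permutation happens on the narrower type.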
struct CanonicalizeContractMatmulToMMT final
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

                                  FilterConstraintType constraint)
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    auto infer = [&](MapList m) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    if (iteratorTypes.size() != 3 ||

    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return failure();

    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
                cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
        op.getIteratorTypes());

  FilterConstraintType filter;
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
    if (!lhsDefOp || !rhsDefOp) {
          "no defining op on contract operands");
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();
    Value acc = op.getAcc();
    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();
    if (isa<IntegerType>(acc.getType())) {
          loc, parentReduction.getVector(), op.getVector());
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
        parentReduction.getAcc());
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;
    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }
  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
struct DropUnitDimFromElementwiseOps final
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!resultVectorType)
      return failure();
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return failure();
      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
        newResultVectorType, op->getAttrs());
  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();
    Type elemType = op.getSourceVectorType().getElementType();
    if (!isa<FloatType>(elemType))
      return failure();
    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
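// BreakDownVectorReduction: unrolls a small 1-D vector.reduction (at most
// maxNumElementsToExtract elements) into per-element vector.extract ops
// combined with scalar arith reductions, preserving fast-math flags and the
// accumulator.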
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
                          unsigned maxNumElementsToExtract,
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));

      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
          extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
          op.getFastmathAttr());

  unsigned maxNumElementsToExtract = 0;
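// FoldArithToVectorOuterProduct: matches arith.mulf/muli of a transposed
// broadcast with another broadcast (in either operand order) and rewrites the
// rank-2 multiply as a vector.outerproduct of the two broadcast sources.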
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
    if (resType.getRank() != 2)
      return failure();

    auto matchOuterProduct =
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();
      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();
      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();
      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(

  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);

  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);

  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(

  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract, BubbleUpBitCastForInsert,
               BubbleUpBitCastForStridedSliceInsert>(

    std::function<bool(vector::BitCastOp)> controlFn,
    PatternBenefit benefit) {
      std::move(controlFn), benefit);

    std::function<LogicalResult(vector::ContractionOp)> constraint,
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));

  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose,
               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(

  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),

  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(

  patterns.add<ChainedReduction>(patterns.getContext(), benefit);

  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);

  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"