#include <type_traits>

// ... (MLIR dialect, IR, and interface includes elided) ...

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"
// Helper to cast the integer elements of an ArrayAttr to `IntType`.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
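// Hedged usage sketch (the op and accessor below are illustrative, not taken
// from this file): materialize the integer contents of an op's ArrayAttr,
// e.g. the offsets of a vector.extract_strided_slice:
//   SmallVector<int64_t> offsets =
//       extractVector<int64_t>(extractOp.getOffsets());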
// ShapeCastOpFolder folds cancelling vector.shape_cast ops away.
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check that 'shapeCastOp' has vector source/result types.
    auto sourceVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
    auto resultVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check that the source operand is itself a shape cast.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        cast<VectorType>(sourceShapeCastOp.getSource().getType());
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // The two shape casts must invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};
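// Illustrative IR (hedged example): a cancelling pair of shape casts
//   %1 = vector.shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = vector.shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
// folds away entirely; uses of %2 are replaced by %0.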
// Convert a vector.multi_reduction over a multiplication into a
// vector.contract.
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    // Non-reduced dims become parallel iterators and results of the
    // destination map; reduced dims become reduction iterators.
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim :
         llvm::enumerate(reduceOp.getReductionMask())) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap =
        AffineMap::get(/*dimCount=*/reduceOp.getReductionMask().size(),
                       /*symbolCount=*/0, exprs, reduceOp.getContext());
    // ... rebuild as a vector.contract; the iterator type attributes are
    // built via IteratorTypeAttr::get(rewriter.getContext(), t).
  }
};
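// Illustrative IR (hedged example):
//   %m = vector.multi_reduction <add>, %mul, %acc [1]
//     : vector<8x16xf32> to vector<8xf32>
// with %mul = arith.mulf %a, %b becomes
//   %m = vector.contract {
//          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                           affine_map<(d0, d1) -> (d0, d1)>,
//                           affine_map<(d0, d1) -> (d0)>],
//          iterator_types = ["parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %acc : vector<8x16xf32>, vector<8x16xf32> into vector<8xf32>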
// Merge a vector.transpose on either contraction input into the
// contraction's indexing maps.
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    // ...
    for (Value *operand : {&lhs, &rhs}) {
      // ...
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      // Compose the inverse of the transpose permutation into the operand's
      // indexing map.
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      // ...
      *operand = transposeOp.getVector();
    }
    // ...
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
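// Illustrative effect (hedged): for
//   %t = vector.transpose %a, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
// used as the LHS of a contraction, the transpose disappears and the LHS
// indexing map is composed with the inverse permutation (d0, d1) -> (d1, d0).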
// Merge a vector.transpose of a contraction's result (and of its
// accumulator) into the contraction's indexing maps.
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    // ... compose the result transpose's permutation map with the
    // contraction's accumulator/result map:
    auto combinedResMap = resTMap.compose(contractMap);
    // ...
    maps.back() = combinedResMap;

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
// Fold a vector.broadcast feeding a contraction operand into the
// contraction's indexing maps, dropping the broadcast (leading) dims.
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // vector.contract can only take vectors as operands.
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
      if (!srcType ||
          srcType.getRank() == broadcast.getResultVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // vector.contract cannot handle inner-dimension broadcasts.
      if (innerDimBroadcast)
        continue;

      // Folding a broadcast onto a reduction dimension of non-unit size
      // would be incorrect.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                         originalDims, contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }
    if (!changed)
      return failure();

    // Compress the dims left unused after composing the broadcast maps, and
    // keep only the surviving iterators.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // At least one reduction iterator must still index both sides, otherwise
    // the op is no longer a valid contraction.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      // ... if reduction iterator i appears in both the LHS and RHS maps:
      //       hasReductionIteratorApplyingOnBothSides = true;
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();
    // ...
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};
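// Illustrative effect (hedged): a leading unit-dim broadcast such as
//   %b = vector.broadcast %v : vector<16xf32> to vector<1x16xf32>
// feeding a contraction operand is absorbed by re-indexing that operand's
// affine map to address %v directly, then compressing the now-unused dims.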
// Reorder cast(broadcast) -> broadcast(cast) so the cast happens on the
// smaller, pre-broadcast value.
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
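// Illustrative IR (hedged example):
//   %b = vector.broadcast %s : f16 to vector<8xf16>
//   %e = arith.extf %b : vector<8xf16> to vector<8xf32>
// is reordered to cast the scalar first and broadcast the result:
//   %c = arith.extf %s : f16 to f32
//   %e = vector.broadcast %c : f32 to vector<8xf32>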
// Reorder elementwise(transpose) -> transpose(elementwise) so a single
// transpose is applied to the elementwise result.
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // All operands must be transpose or constant ops; collect the
    // transposition maps.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // All transposed operands must share one permutation.
    if (!llvm::all_equal(transposeMaps))
      return failure();

    // Constant operands need an inverse transpose; compute the inverse order.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        auto vectorType = srcType.clone(
            cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    // ... re-create the elementwise op on srcValues with vectorType, then
    // transpose its result with transposeMaps.front().
  }
};
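// Illustrative IR (hedged example):
//   %0 = vector.transpose %a, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
//   %1 = vector.transpose %b, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
//   %2 = arith.addf %0, %1 : vector<4x2xf32>
// becomes a single transpose of the untransposed addition:
//   %2 = arith.addf %a, %b : vector<2x4xf32>
//   %3 = vector.transpose %2, [1, 0] : vector<2x4xf32> to vector<4x2xf32>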
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
// Shuffles a vector.bitcast after a vector.extract.
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars from 1-D vectors for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Bail on a single-element source to avoid an infinite loop: this
    // pattern can itself generate such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to more elements for now, e.g.
    // vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // The extract position must be a static attribute.
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Extract the source element that packs the desired scalar and wrap it
    // in a single-element vector.
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Cast it to a vector with the desired scalar's type, then extract.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
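// Illustrative IR (hedged example), expand ratio 2:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : f16 from vector<8xf16>
// becomes an extract of the packing f32 (index 3 / 2 = 1), a narrow bitcast
// to vector<2xf16>, and an extract of element 3 % 2 = 1 from it.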
// Shuffles a vector.bitcast after a vector.extract_strided_slice.
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be
    // implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If an offset is given for the last dimension, scale it down: we now
    // extract from fewer, wider elements.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for the sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());
    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
// Bubbles a vector.bitcast up across a vector.insert.
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
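// Illustrative effect (hedged): for
//   %i = vector.insert %src, %dst [0] : vector<8xf16> into vector<2x8xf16>
//   %c = vector.bitcast %i : vector<2x8xf16> to vector<2x4xf32>
// both %src and %dst are bitcast first and the insert is re-created on the
// narrower f32 vectors, letting the original bitcast disappear.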
// Bubbles a vector.bitcast up across a vector.insert_strided_slice.
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which have no last dimension to operate on.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require insert and destination vectors of the same rank for now.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // The inserted block must bitcast to a whole number of wider elements.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
// Breaks down a 1-D vector.bitcast into a sequence of strided-slice
// extracts, narrower bitcasts, and strided-slice inserts.
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on fixed-length vectors.
    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support rank-1 vectors for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it already bitcasts to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};
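// Illustrative IR (hedged example), shrink ratio 2:
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// becomes a splat of zeros plus, per iteration, an extract_strided_slice of
// four f16s, a vector<4xf16> -> vector<2xf32> bitcast, and an
// insert_strided_slice of the piece into the result.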
// Reorder elementwise(broadcast/splat) -> broadcast(elementwise).
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // ...
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    // ...
    if (op->getResults()[0].getType() != op->getOperand(0).getType())
      return rewriter.notifyMatchFailure(op,
                                         "result and operand type mismatch");
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    // All operands must be broadcast or splat from the same source type.
    Operation *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
    if (!llvm::all_of(op->getOperands(), [&](Value val) {
          if (auto bcast = val.getDefiningOp<vector::BroadcastOp>())
            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
          if (auto splat = val.getDefiningOp<vector::SplatOp>())
            return (splat.getOperand().getType() == lhsBcastOrSplatType);
          return false;
        }))
      return failure();

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands())
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));

    // Re-create the elementwise op on the narrow values, then broadcast.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        lhsBcastOrSplatType, op->getAttrs());
    auto vectorType = op->getResultTypes()[0];
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, vectorType, elementwiseOp->getResults());
    return success();
  }
};
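// Illustrative IR (hedged example):
//   %a = vector.broadcast %x : f32 to vector<4xf32>
//   %b = vector.broadcast %y : f32 to vector<4xf32>
//   %s = arith.addf %a, %b : vector<4xf32>
// becomes the scalar addition broadcast once:
//   %s = arith.addf %x, %y : f32
//   %r = vector.broadcast %s : f32 to vector<4xf32>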
// Build a vector comparison `indices < bound` used to materialize masks.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // 32-bit indices give a higher degree of SIMD parallelism when the caller
  // can assume all indices fit; otherwise compare on 64-bit indices.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}
// Materialize an in-bounds mask for a 1-D vector transfer with an
// out-of-bounds dimension.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds mask with the mask specified on the op.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      // ... mark the dimension as in-bounds ...
    });
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
// Conversion pattern for a vector.create_mask (0-D and 1-D only).
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
// Returns true if all the i1 elements of `constantOp` are set to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constants.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
// Fold an arith.select between an all-true and an all-false vector<1xi1>
// into a broadcast of the scalar condition.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // Only 1-D, fixed-length, single-element vectors for now.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace the select with its condition broadcast to a single-element
    // vector.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
    return success();
  }
};
// Returns the number of inner unit dims that can be folded away from a
// transfer op on `srcType`/`vectorType`, or failure if the source strides
// cannot be determined.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // The vector may be a slice of the memref, so offset the checked index by
  // the rank difference.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // A dim is foldable only if both the memref dim and the vector dim are
    // unit-sized and the stride is 1.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}
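// Hedged examples: for memref<512x16x1x1xf32> (identity strides) and
// vector<16x16x1x1xf32> this returns 2, since both innermost unit dims can
// be dropped. With a scalable inner unit dim (vector<16x16x1x[1]xf32>) it
// returns 0, because a scalable "unit" dim is not foldable.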
// Drop inner-most unit dims from a transfer_read by reading through a
// rank-reducing memref.subview, then shape_cast'ing the result back.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support masks.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    // The dropped dims must all be in-bounds.
    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
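// Illustrative IR (hedged example; types assumed):
//   %v = vector.transfer_read %src[%c0, %c0, %c0], %pad
//     : memref<16x8x1xf32>, vector<4x1xf32>
// becomes a rank-reducing memref.subview to memref<16x8xf32>, a transfer
// read of vector<4xf32> from it, and a vector.shape_cast back to
// vector<4x1xf32>.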
// Drop inner-most unit dims from a transfer_write, mirroring the read
// pattern above.
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support masks.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    // The dropped dims must all be in-bounds.
    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    Location loc = writeOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};
// Canonicalize a vector.contraction with row-major matmul semantics into a
// form suitable for MMT (matmul with the RHS transposed) lowering.
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT": A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector transpose, keeping any zero/sign-extension on the
    // outside of the transpose.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
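// Hedged example: a row-major matmul contraction with indexing maps
//   {(m, k), (k, n), (m, n)}
// is rewritten by transposing the RHS so the maps become the canonical MMT
// form {(m, k), (n, k), (m, n)}; the remaining six layout cases are handled
// by transposing and/or swapping lhs/rhs as in the chain above.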
// Fold a pair of arith extension ops (ExtOp) on the contraction operands
// into the contraction itself, which supports mixed-precision inputs.
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
// Fold a chain of vector.reduction <add> ops into vector adds plus a single
// final reduction.
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // The accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();
    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};
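// Illustrative IR (hedged example):
//   %a = vector.reduction <add>, %x, %acc : vector<8xf32> into f32
//   %b = vector.reduction <add>, %y, %a  : vector<8xf32> into f32
// becomes one vector add feeding a single reduction:
//   %v = arith.addf %x, %y : vector<8xf32>
//   %b = vector.reduction <add>, %v, %acc : vector<8xf32> into f32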
// Drop all non-scalable unit dims from `inVecTy`, keeping at least one
// dimension to avoid creating 0-D vectors.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped: return vector<1xeType>.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
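// Hedged examples: vector<1x8x1xf32> -> vector<8xf32>; a scalable unit dim
// is kept, so vector<1x[1]xf32> -> vector<[1]xf32>; and vector<1x1xf32>
// collapses to vector<1xf32> rather than a 0-D vector.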
// Drop non-scalable unit dims from elementwise vector ops via shape_cast,
// applying the op on the lower-rank operands.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // ...
    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // Elementwise ops guarantee identically-shaped operands, so checking the
    // first operand suffices.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Re-create the elementwise op without unit dims, then restore them on
    // the result via shape_cast.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, resultVectorType, elementwiseOp->getResult(0));
    return success();
  }
};
// Drop unit dims from a vector.transpose by shape_casting them away and
// remapping the permutation onto the remaining dims.
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop unit dims from the permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // If all dims were unit, the source collapsed to vector<1xT>; the
    // permutation must then be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
        loc, sourceTypeWithoutUnitDims, op.getVector());
    auto transposeWithoutUnitDims =
        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), transposeWithoutUnitDims);
    return success();
  }
};
// Drop unit dims from vector iter_args of an scf.for, one argument per
// application of the pattern.
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Find the first iter_arg with droppable unit dims; later arguments are
    // handled by further applications.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;

      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;

      // Replace the iter operand, casting in and out with shape_cast.
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return b.create<vector::ShapeCastOp>(loc, type, source);
      };
      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp,
                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
                                                    replacement, castFn));
      return success();
    }
    return failure();
  }
};
// Drop a redundant zero addend feeding a floating-point vector.reduction
// (the integer case is handled by arith.addi folders).
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();
    Type elemType = op.getSourceVectorType().getElementType();
    if (!isa<FloatType>(elemType))
      return failure();
    // ... match vAdd = addf(addLhs, rhs) where addLhs adds a zero constant,
    // then drop the zero:
    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
                                                 vAdd.getRhs());
    // ... reduce over newAdd instead ...
  }
};
// Break a small 1-D vector.reduction down into per-element extracts combined
// with scalar arith ops.
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};
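// Illustrative IR (hedged example), with maxNumElementsToExtract >= 2:
//   %r = vector.reduction <add>, %v : vector<2xf32> into f32
// becomes:
//   %e0 = vector.extract %v[0] : f32 from vector<2xf32>
//   %e1 = vector.extract %v[1] : f32 from vector<2xf32>
//   %r = arith.addf %e0, %e1 : f32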
// Fold mul(transpose(broadcast(A)), broadcast(B)) into
// vector.outerproduct(A, B). Supports only 1-D-to-2-D broadcasts.
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;

  // Returns whether a vector.broadcast matches the outerproduct pattern:
  // a 1-D-to-2-D broadcast without broadcasted unit dimensions.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail on broadcasts that create unit dims, to avoid generating
    // shape_casts/broadcasts that do not belong in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Avoid broadcasts from scalars or rank-2 sources.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = dyn_cast<VectorType>(mulOp.getResult().getType());
    if (!resType)
      return failure();
    if (resType.getRank() != 2)
      return failure();
    // If operandA is tr(broadcast(A)) and operandB is broadcast(B), with
    // 1-D-to-2-D broadcasts, build vector.outerproduct(A, B).
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity: the transposed op may be either operand.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
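// Illustrative IR (hedged example):
//   %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
//   %lhsT = vector.transpose %lhsBcast, [1, 0]
//     : vector<4x4xi32> to vector<4x4xi32>
//   %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
//   %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
// becomes:
//   %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>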
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
      patterns.getContext());
}

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
                                                 benefit);
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"