#include <type_traits>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"
/// Returns the integers in `arrayAttr` as a vector of `IntType`.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
/// Folds a pair of vector.shape_cast ops that invert each other.
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
    auto resultVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if the shape cast's source is itself a shape cast.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        cast<VectorType>(sourceShapeCastOp.getSource().getType());
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if the two shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};
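// Illustrative IR for the fold above (a hedged sketch; shapes invented):
//
//   %0 = vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>
//   %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
//
// The casts invert each other, so uses of %1 are replaced with %v directly.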
/// Converts a vector.multi_reduction (ADD) fed by an elementwise multiply
/// into a vector.contract.
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    // Non-reduced dims become parallel iterators (and result-map exprs);
    // reduced dims become reduction iterators.
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symbolCount=*/0, exprs, reduceOp.getContext());
    // Build the replacement vector.contract, wrapping each iterator type as
    //   IteratorTypeAttr::get(rewriter.getContext(), t)
    // ...
  }
};
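// Illustrative rewrite (a sketch; shapes invented):
//
//   %m = arith.mulf %a, %b : vector<8x16xf32>
//   %r = vector.multi_reduction <add>, %m, %acc [1]
//       : vector<8x16xf32> to vector<8xf32>
// -->
//   %r = vector.contract with identity maps on %a and %b, a
//        (d0, d1) -> (d0) result map, and ["parallel", "reduction"]
//        iterators.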
/// Folds a vector.transpose feeding a vector.contract's lhs/rhs into the
/// corresponding indexing map.
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    // ...
    for (Value *operand : {&lhs, &rhs}) {
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      // Compose the inverse of the transpose permutation into the operand's
      // indexing map and use the pre-transpose vector instead.
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      // ...
      *operand = transposeOp.getVector();
    }
    // ...
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
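// Illustrative rewrite (a sketch):
//
//   %at = vector.transpose %a, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
//   %r  = vector.contract {...} %at, %b, %acc ...
// -->
//   %r  = vector.contract {lhs map composed with the inverse permutation}
//         %a, %b, %acc ...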
/// Pushes a transpose of a single-use contract result into the contract by
/// transposing the accumulator and composing the result map.
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();
    // The accumulator must also come from a transpose.
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    // Compose the result transpose's permutation into the result map.
    // ...
    auto combinedResMap = resTMap.compose(contractMap);
    // ...
    maps.back() = combinedResMap;
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
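// Illustrative rewrite (a sketch; %c has a single use):
//
//   %accT = vector.transpose %acc, [1, 0]
//   %c    = vector.contract {...} %a, %b, %accT
//   %r    = vector.transpose %c, [1, 0]
// -->
//   %r    = vector.contract {result map composed with the permutation}
//           %a, %b, %acc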
/// Folds vector.broadcast ops that only add leading dims into a contract's
/// indexing maps.
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // A contract can only take vectors as operands.
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
      if (!srcType ||
          srcType.getRank() == broadcast.getResultVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner-dimension broadcasts.
      if (innerDimBroadcast)
        continue;

      // It would be incorrect to fold a broadcast onto a reduction dimension
      // of non-unit size.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getResultVectorType().getRank(),
                         /*symbolCount=*/0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }
    if (!changed)
      return failure();

    // Compress the dims that became unused after composing the maps, and
    // keep only the surviving iterator types.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }

    // A valid contract needs at least one reduction dim used by both sides.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      // ...
        hasReductionIteratorApplyingOnBothSides = true;
      // ...
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};
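// Illustrative rewrite (a sketch): a broadcast that only adds a leading unit
// dim is absorbed into the operand's indexing map:
//
//   %bc = vector.broadcast %v : vector<4x4xf32> to vector<1x4x4xf32>
//   %r  = vector.contract {...} %bc, %b, %acc ...
// -->
//   %r  = vector.contract {lhs map composed with the broadcast map, unused
//         dims compressed} %v, %b, %acc ...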
/// Reorders elementwise cast ops and vector.broadcast so the cast happens on
/// the narrower pre-broadcast value.
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
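// Illustrative rewrite (a sketch):
//
//   %b = vector.broadcast %s : vector<4xi8> to vector<8x4xi8>
//   %e = arith.extsi %b : vector<8x4xi8> to vector<8x4xi32>
// -->
//   %e = arith.extsi %s : vector<4xi8> to vector<4xi32>
//   %b = vector.broadcast %e : vector<4xi32> to vector<8x4xi32>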
/// Moves elementwise ops above their vector.transpose operands, so only one
/// transpose remains on the result.
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // All operands must be transpose or constant ops; collect the maps.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // All transpose permutations must be identical.
    if (!llvm::all_equal(transposeMaps))
      return failure();

    // Constant operands need an inverse transpose; compute the inverse order.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant: insert the inverse transpose for it.
        auto vectorType =
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResults()[0],
        transposeMaps.front());
    return success();
  }
};
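// Illustrative rewrite (a sketch):
//
//   %ta = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
//   %tb = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
//   %r  = arith.addf %ta, %tb : vector<2x4xf32>
// -->
//   %s = arith.addf %a, %b : vector<4x2xf32>
//   %r = vector.transpose %s, [1, 0] : vector<4x2xf32> to vector<2x4xf32>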
/// Returns the values of an ArrayAttr of IntegerAttrs as int64_t.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
/// Bubbles a vector.bitcast below a vector.extract by extracting the packed
/// wide element first and bitcasting only that element.
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail when the cast source has a single element: this pattern can
    // generate such cases itself, and matching them would loop forever.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now, e.g.
    // vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // The extract position must be a static attribute.
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Get the single scalar (as a vector) in the source that packs the
    // desired scalar, e.g. extract vector<1xf32> from vector<4xf32>.
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Cast it to a vector with the desired scalar's type, e.g. f32 ->
    // vector<2xf16>.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
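// Illustrative rewrite (a sketch; expandRatio = 8 / 4 = 2):
//
//   %c = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %e = vector.extract %c[5] : f16 from vector<8xf16>
// -->
//   %p = vector.extract %src[2] : f32 from vector<4xf32>   // 5 / 2
//   ... insert %p into a zero vector<1xf32>, bitcast to vector<2xf16> ...
//   %e = vector.extract %pc[1] : f16 from vector<2xf16>    // 5 % 2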
/// Bubbles a vector.bitcast below a vector.extract_strided_slice by slicing
/// the cast's source first (with rescaled offsets/sizes).
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // Scale down the last dimension's offset, since the new extract reads
    // from fewer (wider) elements; fewer offsets than the rank implicitly
    // select the full last dimension.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for the sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
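// Illustrative rewrite (a sketch; expandRatio = 2):
//
//   %c = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %e = vector.extract_strided_slice %c
//       {offsets = [4], sizes = [4], strides = [1]}
//       : vector<8xf16> to vector<4xf16>
// -->
//   %s = vector.extract_strided_slice %src
//       {offsets = [2], sizes = [2], strides = [1]}
//       : vector<4xf32> to vector<2xf32>
//   %e = vector.bitcast %s : vector<2xf32> to vector<4xf16>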
/// Bubbles a vector.bitcast above a vector.insert by bitcasting both the
/// inserted source and the destination, rescaling the last dimension.
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
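// Illustrative rewrite (a sketch; the last dim shrinks by ratio = 2):
//
//   %i = vector.insert %src, %dst [0] : vector<8xf16> into vector<2x8xf16>
//   %c = vector.bitcast %i : vector<2x8xf16> to vector<2x4xf32>
// -->
//   %s = vector.bitcast %src : vector<8xf16> to vector<4xf32>
//   %d = vector.bitcast %dst : vector<2x8xf16> to vector<2x4xf32>
//   %c = vector.insert %s, %d [0] : vector<4xf32> into vector<2x4xf32>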
/// Bubbles a vector.bitcast above a vector.insert_strided_slice, dividing the
/// last-dim offset by the shrink ratio.
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from an InsertStridedSliceOp.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert's source and destination to have the same rank;
    // other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // The insert's source shape must be castable to the destination elt type.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
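// Illustrative rewrite (a sketch; shrinkRatio = 2, offsets divided by it):
//
//   %i = vector.insert_strided_slice %src, %dst
//       {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %c = vector.bitcast %i : vector<8xf16> to vector<4xf32>
// -->
//   %s = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %d = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %c = vector.insert_strided_slice %s, %d
//       {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>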
/// Breaks a 1-D shrinking vector.bitcast into per-piece slices, bitcasting
/// and re-inserting each piece.
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Only support the rank-1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it already bitcasts to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    // Start from a zero splat and fill in the pieces.
    // ...
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    VectorType newCastDstType =
        VectorType::get({castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), /*offsets=*/ /*...*/
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res, /*offsets=*/ /*...*/ strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};
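// Illustrative rewrite (a sketch; shrinkRatio = 2):
//
//   %c = vector.bitcast %src : vector<8xf16> to vector<4xf32>
// -->
//   %z  = vector.splat of 0.0 : vector<4xf32>
//   %s0 = vector.extract_strided_slice %src
//        {offsets = [0], sizes = [4], strides = [1]}
//   %b0 = vector.bitcast %s0 : vector<4xf16> to vector<2xf32>
//   %r0 = vector.insert_strided_slice %b0, %z {offsets = [0], strides = [1]}
//   ... likewise for the second half at source offset [4] / result offset [2].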
/// Moves elementwise ops above identically-broadcast operands, so the op runs
/// on the narrow sources and is broadcast once.
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1)
      return failure();
    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    if (op->getNumOperands() == 0)
      return failure();
    if (op->getResults()[0].getType() != op->getOperand(0).getType())
      return rewriter.notifyMatchFailure(op,
                                         "result and operand type mismatch");
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    // Get the type of the lhs operand.
    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();

    // All operands must be broadcast/splat from the same type; otherwise the
    // re-ordering wouldn't be safe.
    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
          if (bcast)
            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
          auto splat = val.getDefiningOp<vector::SplatOp>();
          if (splat)
            return (splat.getOperand().getType() == lhsBcastOrSplatType);
          return false;
        }))
      return failure();

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands())
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));

    // Perform the op on the narrow values, then broadcast the result.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        lhsBcastOrSplatType, op->getAttrs());
    auto vectorType = op->getResultTypes()[0];
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, vectorType, elementwiseOp->getResults());
    return success();
  }
};
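// Illustrative rewrite (a sketch):
//
//   %ba = vector.broadcast %a : f32 to vector<4xf32>
//   %bb = vector.broadcast %b : f32 to vector<4xf32>
//   %r  = arith.addf %ba, %bb : vector<4xf32>
// -->
//   %s = arith.addf %a, %b : f32
//   %r = vector.broadcast %s : f32 to vector<4xf32>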
/// Builds a vector comparison of the indices [0 .. dim) (plus an optional
/// offset) against the bound `b`, producing an i1 vector mask.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If all indices fit in 32 bits, compare in 32-bit for more SIMD
  // parallelism; otherwise use 64-bit indices.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}
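// Shape of the emitted IR (a sketch; force32BitVectorIndices = true, dim = 4,
// no offset):
//
//   %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %bnd = vector.splat %b : vector<4xi32>
//   %msk = arith.cmpi slt, %cst, %bnd : vector<4xi32>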
/// Materializes an explicit mask for 1-D transfer ops with an out-of-bounds
/// dimension, so the transfer itself can be marked in-bounds.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds mask with the mask specified by the op.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  const bool force32BitVectorIndices;
};
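// Illustrative rewrite (a sketch; a 1-D transfer with an out-of-bounds dim):
//
//   %v = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<8xf32>
// -->
//   %d = memref.dim %src, %c0 : memref<?xf32>
//   %b = arith.subi %d, %i : index
//   %m = vector.create_mask %b : vector<8xi1>
//   %v = vector.transfer_read %src[%i], %pad, %m {in_bounds = [true]}
//       : memref<?xf32>, vector<8xf32>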
/// Lowers 0-D and 1-D vector.create_mask ops to the vector comparison above.
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
/// Returns true if all the i1 elements of `constantOp` are set to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constant.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
  // Folds an arith.select with a scalar condition between an all-true and an
  // all-false i1 vector constant.
  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();

    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace the select with a direct materialization of the condition.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    // ...
  }
/// Returns the number of trailing dims that are contiguous unit dims in both
/// the memref and the vector, and can therefore be folded away.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // The vector may be a slice of the memref, so offset the checked index by
  // the rank difference.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}
/// Drops trailing contiguous unit dims from a transfer_read by reading
/// through a rank-reduced subview and shape_casting the result back.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
            strides));
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
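// Illustrative rewrite (a sketch; the trailing unit dim is contiguous):
//
//   %v = vector.transfer_read %src[%c0, %c0, %c0], %pad
//       : memref<8x4x1xf32>, vector<4x1xf32>
// -->
//   %sv = memref.subview %src[0, 0, 0] [8, 4, 1] [1, 1, 1]
//       : memref<8x4x1xf32> to memref<8x4xf32>
//   %v0 = vector.transfer_read %sv[%c0, %c0], %pad
//       : memref<8x4xf32>, vector<4xf32>
//   %v  = vector.shape_cast %v0 : vector<4xf32> to vector<4x1xf32>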
/// Mirror of the read case: shape_cast the vector down, then write through a
/// rank-reduced subview.
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    Location loc = writeOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
            strides));
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};
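// Mirror of the read example (a sketch):
//
//   vector.transfer_write %v, %dst[%c0, %c0, %c0]
//       : vector<4x1xf32>, memref<8x4x1xf32>
// -->
//   %sc = vector.shape_cast %v : vector<4x1xf32> to vector<4xf32>
//   %sv = memref.subview %dst[0, 0, 0] [8, 4, 1] [1, 1, 1]
//       : memref<8x4x1xf32> to memref<8x4xf32>
//   vector.transfer_write %sc, %sv[%c0, %c0] : vector<4xf32>, memref<8x4xf32>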
/// Normalizes a row-major-matmul-like vector.contract into the "MMT" form
/// (A row-major, B col-major, C row-major) by transposing lhs/rhs as needed.
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return failure();

    // The canonical form: A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return failure();

    // Create a vector transpose, re-applying a zero/sign-extension at the end
    // if one was applied to the operand.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
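// Illustrative rewrite (a sketch): a plain row-major matmul with maps
// {(m, k), (k, n), (m, n)} reaches the canonical MMT form
// {(m, k), (n, k), (m, n)} by transposing the rhs:
//
//   %rhsT = vector.transpose %rhs, [1, 0] : vector<8x4xf32> to vector<4x8xf32>
//   %res  = vector.contract {canonical maps} %lhs, %rhsT, %acc ...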
/// Folds arith extension ops applied to both contract operands into the
/// mixed-precision contract itself.
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
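// Illustrative rewrite (a sketch, with ExtOp = arith.extf):
//
//   %le = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
//   %re = arith.extf %rhs : vector<8x4xf16> to vector<8x4xf32>
//   %r  = vector.contract {...} %le, %re, %acc ...
// -->
//   %r  = vector.contract {...} %lhs, %rhs, %acc ..., relying on the
//         contract's support for mixed-precision operands.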
  // Folds two chained additive reductions into one reduction over an
  // elementwise add of the two source vectors.
  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // The accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();
    if (!acc.getType().isIntOrFloat())
      return failure();

    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
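// Illustrative rewrite (a sketch):
//
//   %a = vector.reduction <add>, %x, %acc : vector<8xf32> into f32
//   %b = vector.reduction <add>, %y, %a : vector<8xf32> into f32
// -->
//   %s = arith.addf %x, %y : vector<8xf32>
//   %b = vector.reduction <add>, %s, %acc : vector<8xf32> into f32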
/// Returns `inVecTy` with all non-scalable unit dims dropped. If every dim is
/// a non-scalable unit dim, returns a 1-element vector instead.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped: return a 1-element vector.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
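// For example (a sketch): vector<1x8x[4]x1xf32> maps to vector<8x[4]xf32>;
// scalable unit dims such as [1] are preserved, and an all-unit shape like
// vector<1x1xf32> collapses to vector<1xf32>.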
/// Drops non-scalable unit dims from elementwise ops via shape_cast.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // For `Elementwise` ops all operands have identical shapes, so checking
    // the first one suffices.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return failure();

      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Create the updated elementwise op, then restore the original shape.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, resultVectorType, elementwiseOp->getResult(0));
    return success();
  }
};
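// Illustrative rewrite (a sketch):
//
//   %r = arith.addf %a, %b : vector<1x4x1xf32>
// -->
//   %a0 = vector.shape_cast %a : vector<1x4x1xf32> to vector<4xf32>
//   %b0 = vector.shape_cast %b : vector<1x4x1xf32> to vector<4xf32>
//   %r0 = arith.addf %a0, %b0 : vector<4xf32>
//   %r  = vector.shape_cast %r0 : vector<4xf32> to vector<1x4x1xf32>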
/// Drops non-scalable unit dims from vector.transpose via shape_cast,
/// adjusting the permutation accordingly.
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop the unit dims from the permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // If all dims were unit dims, the new source is vector<1xT> and the
    // permutation must be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
        loc, sourceTypeWithoutUnitDims, op.getVector());
    auto tranposeWithoutUnitDims =
        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), tranposeWithoutUnitDims);
    return success();
  }
};
/// Drops non-scalable unit dims from scf.for iter_args holding vectors.
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Rewrite the first iter_arg with droppable unit dims; further
    // applications of this pattern handle the later arguments.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;

      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;

      // Create a new ForOp with that iter operand replaced.
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return b.create<vector::ShapeCastOp>(loc, type, source);
      };
      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp,
                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
                                                    replacement, castFn));
      return success();
    }
    return failure();
  }
};
  // Drops a redundant "+ 0.0" (introduced e.g. by folding chained reductions)
  // from the vector feeding an ADD reduction.
  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other reduction kinds and their identity values.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by arith folders.
    if (!isa<FloatType>(elemType))
      return failure();

    // Match addf(addf(x, +0.0), y) feeding the reduction, then rebuild the
    // add without the zero.
    // ...
    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
                                                 vAdd.getRhs());
    // ...
  }
/// Breaks small vector reductions down into a chain of scalar arith ops over
/// extracted elements.
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};
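// Illustrative rewrite (a sketch; maxNumElementsToExtract >= 2):
//
//   %r = vector.reduction <add>, %v : vector<2xf32> into f32
// -->
//   %e0 = vector.extract %v[0] : f32 from vector<2xf32>
//   %e1 = vector.extract %v[1] : f32 from vector<2xf32>
//   %r  = arith.addf %e0, %e1 : f32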
/// Matches mul(transpose(broadcast(A)), broadcast(B)) over a 2-D result and
/// rewrites it as vector.outerproduct(A, B).
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;

  // Returns whether a vector.broadcast matches the requirements of an
  // outerproduct operand: a pure rank-expanding 1D-to-2D broadcast.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail if it is not a 1-to-2 dimension broadcast, to avoid generating
    // shape_casts/broadcasts which do not belong in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Avoid broadcasts like f32 or vector<f32> -> ResType.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
    if (resType.getRank() != 2)
      return failure();

    /// If operandA can be written as tr(broadcast(A)) and operandB as
    /// broadcast(B), where both broadcasts are 1D-to-2D, create and return
    /// vector.outerproduct(A, B); fail otherwise.
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity: the transposed op is the outerproduct LHS.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
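// Illustrative rewrite (a sketch):
//
//   %bc = vector.broadcast %col : vector<4xf32> to vector<8x4xf32>
//   %tr = vector.transpose %bc, [1, 0] : vector<8x4xf32> to vector<4x8xf32>
//   %br = vector.broadcast %row : vector<8xf32> to vector<4x8xf32>
//   %m  = arith.mulf %tr, %br : vector<4x8xf32>
// -->
//   %m = vector.outerproduct %col, %row : vector<4xf32>, vector<8xf32>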
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
      patterns.getContext());
}

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
                                                 benefit);
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}
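// Minimal usage sketch for the populate entry points above (the surrounding
// pass and funcOp are hypothetical; applyPatternsAndFoldGreedily is the
// standard greedy rewrite driver):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorReductionToContractPatterns(patterns);
//   vector::populateSinkVectorOpsPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();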
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"