#include <type_traits>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"
// Helper that extracts the integer values of an ArrayAttr of IntegerAttrs into
// a vector of the requested integer type.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
/// Folds a pair of vector.shape_cast ops that invert each other.
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Both the source and the result must be vectors.
    auto sourceVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
    auto resultVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // The source must itself be produced by a vector.shape_cast.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        cast<VectorType>(sourceShapeCastOp.getSource().getType());
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check that the two shape casts invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};
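// Minimal usage sketch, not part of VectorTransforms.cpp: the folder above is
// exposed through populateShapeCastFoldingPatterns() and is normally applied
// with the greedy rewrite driver. `foldShapeCasts` is a hypothetical helper
// name; the sketch assumes
// mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h and
// mlir/Transforms/GreedyPatternRewriteDriver.h are included.
static LogicalResult foldShapeCasts(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateShapeCastFoldingPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}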
/// Converts a vector.multi_reduction <add> fed by an elementwise multiply into
/// a vector.contraction.
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    // Non-reduced dims become parallel iterators and keep a result expression;
    // reduced dims become reduction iterators.
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim :
         llvm::enumerate(reduceOp.getReductionMask())) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/iteratorTypes.size(),
                                 /*symbolCount=*/0, exprs,
                                 reduceOp.getContext());
    // Rebuild as a vector.contraction on the multiply operands; each iterator
    // type is converted to an attribute with
    //   IteratorTypeAttr::get(rewriter.getContext(), t);
    // ...
    return success();
  }
};
/// Folds vector.transpose operands into a vector.contraction by composing the
/// inverse permutation into the operand's indexing map.
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
/// Pulls a vector.transpose of a contraction's result (together with the
/// matching transpose of its accumulator) into the contraction by updating the
/// result indexing map.
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    // Compose the result transpose's permutation with the contraction's
    // accumulator/result map.
    // ...
    auto combinedResMap = resTMap.compose(contractMap);
    // ...
    maps.back() = combinedResMap;
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
/// Folds vector.broadcast operands into a vector.contraction by composing the
/// broadcast's implicit map into the contraction's indexing maps.
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // vector.contraction can only take vector operands.
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
      if (!srcType ||
          srcType.getRank() == broadcast.getResultVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contractions do not support broadcasts of inner dimensions.
      if (innerDimBroadcast)
        continue;

      // Folding a broadcast onto a reduction dimension of non-unit size would
      // change the result.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                         originalDims, contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }
    if (!changed)
      return failure();

    // Drop the dimensions that became unused and rebuild the iterator list.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }

    // Bail out if no reduction dimension is shared by both operands anymore,
    // since that would produce an invalid contraction.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      // ... (dimension i must appear in both operand maps)
        hasReductionIteratorApplyingOnBothSides = true;
        break;
      // ...
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};
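// Usage sketch (not part of the original file): the four patterns above are
// registered together by populateVectorReductionToContractPatterns(), declared
// in mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h. The pass wrapper
// below is a hypothetical example and also assumes mlir/Pass/Pass.h, the func
// dialect headers, and mlir/Transforms/GreedyPatternRewriteDriver.h.
namespace {
struct ReductionToContractExamplePass
    : public PassWrapper<ReductionToContractExamplePass,
                         OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    vector::populateVectorReductionToContractPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace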
/// Moves a cast that consumes a vector.broadcast before the broadcast, so the
/// cast is performed on the smaller (pre-broadcast) value.
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    // ... (require a single operand defined by a vector.broadcast `bcastOp`)
    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
/// Reorders elementwise(transpose) into transpose(elementwise), so a single
/// transpose is applied to the result instead of one per operand.
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Collect the permutations of all transposed operands.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      }
      // ... (otherwise require a splat constant operand)
    }
    if (transposeMaps.empty())
      return failure();
    // All transposes must use the same permutation.
    if (!llvm::all_equal(transposeMaps))
      return failure();

    // Inverse permutation, used to transpose splat operands into the
    // pre-transpose layout.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
        continue;
      }
      auto vectorType =
          srcType.clone(cast<VectorType>(operand.getType()).getElementType());
      srcValues.push_back(rewriter.create<vector::TransposeOp>(
          operand.getLoc(), vectorType, operand, invOrder));
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    // Recreate the elementwise op on the un-transposed operands, then
    // transpose its result once with
    //   transposeMaps.front());
    // ...
    return success();
  }
};
// Returns the values in `arrayAttr` as a vector of integers.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
/// Bubbles a vector.bitcast below a scalar vector.extract:
/// extract(bitcast(x)) becomes extract(bitcast(extract(x))).
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars from 1-D vectors for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Bail on single-element sources to avoid an infinite rewrite loop.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casts that expand the element count, e.g.
    // vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
      assert(values[0].is<Attribute>() && "Unexpected non-constant index");
      return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
    };
    uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
    Location loc = extractOp.getLoc();

    // Extract the wide element that packs the requested narrow element and
    // wrap it into a single-element vector.
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Bit-cast to `expandRatio` narrow elements, then extract the one
    // requested originally.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
/// Bubbles a vector.bitcast below a vector.extract_strided_slice by scaling
/// the slice's innermost offset/size to the source element type.
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // Scale down the last dimension's offset, given we extract from fewer
    // (wider) elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for the sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
/// Bubbles a vector.bitcast above a vector.insert by bit-casting both the
/// inserted source and the destination.
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
/// Bubbles a vector.bitcast above a vector.insert_strided_slice by bit-casting
/// both the inserted source and the destination and rescaling the offsets.
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from an insert_strided_slice.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert's source and destination to have the same rank.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // The insert's source shape must be castable to the destination element
    // type.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
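// Usage sketch (not part of the original file): the four bubble-up/bubble-down
// bitcast patterns above are exposed via populateBubbleVectorBitCastOpPatterns().
// `bubbleBitCasts` is a hypothetical helper name; the sketch assumes
// VectorRewritePatterns.h and GreedyPatternRewriteDriver.h.
static LogicalResult bubbleBitCasts(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Give these patterns a higher benefit than the default of 1 so they run
  // before competing patterns in the same set.
  vector::populateBubbleVectorBitCastOpPatterns(patterns, /*benefit=*/2);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}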
/// Breaks a 1-D vector.bitcast that shrinks the element count into a series of
/// smaller extract/bitcast/insert_strided_slice ops.
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on 1-D vectors.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if the cast already produces a single element per slice.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    // ... (per-slice shape, strides, and the narrow slice result type)
    VectorType newCastDstType =
        VectorType::get(/*...*/ {castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), /*offsets=*/ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          /*offsets=*/ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};
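// Usage sketch (not part of the original file): the control function passed to
// populateBreakDownVectorBitCastOpPatterns() decides which bitcasts get broken
// down. The heuristic below is hypothetical and only illustrates the hook;
// assumes VectorRewritePatterns.h and GreedyPatternRewriteDriver.h.
static LogicalResult breakDownWideBitCasts(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  auto controlFn = [](vector::BitCastOp op) {
    // Hypothetical heuristic: only break down casts with more than four
    // source elements.
    return op.getSourceVectorType().getNumElements() > 4;
  };
  vector::populateBreakDownVectorBitCastOpPatterns(patterns, controlFn);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}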
/// Sinks an elementwise op below its broadcast/splat operands:
/// elementwise(broadcast(a), broadcast(b)) becomes broadcast(elementwise(a, b)).
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    if (op->getResults()[0].getType() != op->getOperand(0).getType())
      return rewriter.notifyMatchFailure(op,
                                         "result and operand type mismatch");
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    // All operands must be broadcast/splat from the same source type.
    Operation *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
    if (!llvm::all_of(op->getOperands(), [&](Value val) {
          if (auto bcast = val.getDefiningOp<vector::BroadcastOp>())
            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
          if (auto splat = val.getDefiningOp<vector::SplatOp>())
            return (splat.getOperand().getType() == lhsBcastOrSplatType);
          return false;
        }))
      return failure();

    // Rebuild the elementwise op on the pre-broadcast values, then broadcast
    // its result.
    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands())
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        lhsBcastOrSplatType, op->getAttrs());
    auto vectorType = op->getResultTypes()[0];
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, vectorType, elementwiseOp->getResults());
    return success();
  }
};
/// Builds a constant vector of sequential indices and compares it (signed
/// less-than) against the broadcast bound `b`, yielding an i1 mask vector.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // Use 32-bit indices when requested; this exposes more SIMD parallelism.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    // ... (0-D, i32 case)
  } else if (dim == 0) {
    // ... (0-D, i64 case)
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}
/// Adds an in-bounds mask to a 1-D transfer op that has an out-of-bounds
/// dimension, so the op no longer needs out-of-bounds handling.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();
    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask covering [0 .. dim - offset) of the last
    // (and only) vector dimension.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect with the mask already carried by the op.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
/// Converts a 0-D/1-D vector.create_mask into a vector comparison against a
/// constant index vector (see buildVectorComparison above).
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
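// Usage sketch (not part of the original file): both mask-materialization
// patterns are registered via populateVectorMaskMaterializationPatterns(); the
// boolean selects 32-bit index vectors for the comparisons. The helper name is
// hypothetical; assumes VectorRewritePatterns.h and
// GreedyPatternRewriteDriver.h.
static LogicalResult materializeMasks(Operation *root,
                                      bool force32BitVectorIndices) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorMaskMaterializationPatterns(patterns,
                                                    force32BitVectorIndices);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}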
/// Returns true if `constantOp` is a dense i1 splat equal to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constants.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
/// Folds arith.select between an all-true and an all-false single-element i1
/// vector into a cast of the scalar condition.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // TODO: Support n-D and scalable vectors.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();

    // TODO: Support vectors with multiple elements.
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace the select with its condition, cast/broadcast to the result
    // vector type.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    // ...
  }
};
/// Returns the number of trailing dimensions that are unit-sized in
/// `vectorType` and contiguous unit dims in `srcType`, i.e. the dims a vector
/// transfer can drop.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // The vector may be a minor slice of the memref, so offset the checked
  // memref dimension by the rank difference.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}
/// Drops contiguous innermost unit dims from a vector.transfer_read by reading
/// a lower-rank vector from a rank-reduced subview and shape-casting it back.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // TODO: support masks.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (!srcType)
      return failure();
    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();
    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    // All dropped dims must be in-bounds.
    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
            strides));
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
/// Drops contiguous innermost unit dims from a vector.transfer_write by
/// shape-casting the vector and writing it to a rank-reduced subview.
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();
    // TODO: support masks.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
    if (!srcType)
      return failure();
    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();
    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    // All dropped dims must be in-bounds.
    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = writeOp.getLoc();
    // ... (offsets = 0, sizes = memref sizes, strides = 1, as in the read case)
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
            strides));
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};
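// Usage sketch (not part of the original file): both transfer patterns above
// are registered via
// populateVectorTransferCollapseInnerMostContiguousDimsPatterns(). The helper
// name is hypothetical; assumes VectorRewritePatterns.h and
// GreedyPatternRewriteDriver.h.
static LogicalResult collapseInnerMostTransferDims(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
      patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}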
/// Canonicalizes a vector.contraction with row-major matmul semantics into the
/// MMT form (RHS transposed), inserting vector.transpose ops as needed.
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT": A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector.transpose, re-applying any zero/sign-extension on top of
    // the transposed value.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
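// Usage sketch (not part of the original file): the filter constraint lets
// callers limit which contractions populateVectorContractCanonicalizeMatmulToMMT()
// rewrites. The lambda below is a hypothetical example that only rewrites
// contractions with floating-point LHS operands; assumes
// VectorRewritePatterns.h, mlir/IR/TypeUtilities.h, and
// GreedyPatternRewriteDriver.h.
static LogicalResult canonicalizeFloatMatmulsToMMT(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  auto onlyFloatContractions = [](vector::ContractionOp op) -> LogicalResult {
    return success(
        isa<FloatType>(getElementTypeOrSelf(op.getLhs().getType())));
  };
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns,
                                                        onlyFloatContractions);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}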
/// Folds an extension op (ExtOp) on both contraction operands into the
/// vector.contraction itself, which supports mixed-precision operands.
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
/// Folds two chained vector.reduction <add> ops into a vector add followed by
/// a single reduction, which needs fewer cross-lane operations.
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // The accumulator must itself be produced by a vector.reduction.
    Value acc = op.getAcc();
    // ... (bail out when there is no usable accumulator)
    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    // Add the two reduced vectors first, then reduce once.
    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};
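// Usage sketch (not part of the original file): this folding is exposed via
// populateChainedVectorReductionFoldingPatterns(); the helper name is
// hypothetical and assumes VectorRewritePatterns.h and
// GreedyPatternRewriteDriver.h.
static LogicalResult foldChainedReductions(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateChainedVectorReductionFoldingPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}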
/// Returns `inVecTy` with all non-scalable unit dims dropped. If every dim is
/// a unit dim, a single-element vector type is returned instead.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims were dropped: produce a vector<1xty>.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
/// Drops non-scalable unit dims from the operands and result of an elementwise
/// op, inserting vector.shape_cast ops around a lower-rank clone of the op.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // For elementwise ops all operands share the result shape; checking the
    // first operand suffices.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Recreate the elementwise op without unit dims, then cast its result back
    // to the original type.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, resultVectorType, elementwiseOp->getResult(0));
    return success();
  }
};
/// Drops non-scalable unit dims from a vector.transpose by shape-casting the
/// source, transposing with a reduced permutation, and casting back.
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map from dim index to the number of dims dropped before it.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop unit dims from the transpose permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // If every dim was a unit dim the reduced type is vector<1xty>; the
    // permutation then must be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
        loc, sourceTypeWithoutUnitDims, op.getVector());
    auto tranposeWithoutUnitDims =
        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), tranposeWithoutUnitDims);
    return success();
  }
};
/// Drops non-scalable unit dims from vector-typed scf.for iter_args by casting
/// them with vector.shape_cast and rebuilding the loop.
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Rewrite the first applicable iter_arg; repeated pattern applications
    // handle the rest.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;
      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return b.create<vector::ShapeCastOp>(loc, type, source);
      };
      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp,
                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
                                                    replacement, castFn));
      return success();
    }
    return failure();
  }
};
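// Usage sketch (not part of the original file): the unit-dim-dropping patterns
// above (elementwise, transpose, and scf.for variants) plus ShapeCastOpFolder
// are grouped behind populateDropUnitDimWithShapeCastPatterns(). The helper
// name is hypothetical; assumes VectorRewritePatterns.h and
// GreedyPatternRewriteDriver.h.
static LogicalResult dropUnitDims(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateDropUnitDimWithShapeCastPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}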
/// Removes a redundant zero addend feeding a floating-point
/// vector.reduction <add>; such adds are introduced by ChainedReduction.
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // Integer adds of zero are handled by the arith folders; only floats here.
    if (!isa<FloatType>(elemType))
      return failure();

    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!addLhs || !matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
                                                 vAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
                                                     op.getAcc());
    return success();
  }
};
/// Breaks a small 1-D vector.reduction into per-element extracts combined with
/// scalar arith ops.
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};
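// Usage sketch (not part of the original file):
// populateBreakDownVectorReductionPatterns() forwards `maxNumElementsToExtract`
// to the constructor above, so only reductions over at most that many elements
// are unrolled. The helper name is hypothetical; assumes
// VectorRewritePatterns.h and GreedyPatternRewriteDriver.h.
static LogicalResult breakDownSmallReductions(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateBreakDownVectorReductionPatterns(
      patterns, /*maxNumElementsToExtract=*/4);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}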
/// Folds mul(transpose(broadcast(A)), broadcast(B)), where both broadcasts are
/// 1-D to 2-D, into vector.outerproduct(A, B).
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;

  // Returns true if `broadcastOp` is a 1-D to 2-D broadcast without
  // broadcasted unit dimensions, i.e. a valid outerproduct operand source.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Reject scalar (or 2-D) sources.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
    if (!resType)
      return failure();
    if (resType.getRank() != 2)
      return failure();

    // If operandA is tr(broadcast(A)) and operandB is broadcast(B), create
    // vector.outerproduct(A, B).
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Only a true 2-D matrix transpose qualifies.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // The multiply is commutative; retry with swapped operands.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
      patterns.getContext());
}

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
}

void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
                                                 benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  // ...
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
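// Consolidated usage sketch (not part of VectorTransforms.cpp): a single
// pattern set can combine several of the populate functions defined above and
// apply them in one greedy-rewrite sweep; the driver iterates to a fixed
// point, so the order of the populate calls does not matter. Assumes
// mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h and
// mlir/Transforms/GreedyPatternRewriteDriver.h; the helper name is
// hypothetical.
static LogicalResult runVectorToVectorCleanups(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateShapeCastFoldingPatterns(patterns);
  vector::populateDropUnitDimWithShapeCastPatterns(patterns);
  vector::populateVectorReductionToContractPatterns(patterns);
  vector::populateSinkVectorOpsPatterns(patterns);
  vector::populateChainedVectorReductionFoldingPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}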