19 #include <type_traits>
43 #include "llvm/ADT/DenseSet.h"
44 #include "llvm/ADT/MapVector.h"
45 #include "llvm/ADT/STLExtras.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/FormatVariadic.h"
49 #include "llvm/Support/raw_ostream.h"
51 #define DEBUG_TYPE "vector-to-vector"
56 template <
typename IntType>
58 return llvm::to_vector<4>(llvm::map_range(
59 arrayAttr.getAsRange<IntegerAttr>(),
60 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
94 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
97 auto sourceVectorType =
98 dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
99 auto resultVectorType =
100 dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
101 if (!sourceVectorType || !resultVectorType)
105 auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
106 shapeCastOp.getSource().getDefiningOp());
107 if (!sourceShapeCastOp)
109 auto operandSourceVectorType =
110 cast<VectorType>(sourceShapeCastOp.getSource().getType());
111 auto operandResultVectorType = sourceShapeCastOp.getType();
114 if (operandSourceVectorType != resultVectorType ||
115 operandResultVectorType != sourceVectorType)
118 rewriter.
replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
140 struct MultiReduceToContract
144 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
146 if (reduceOp.getKind() != vector::CombiningKind::ADD)
148 Operation *mulOp = reduceOp.getSource().getDefiningOp();
149 if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
156 if (!isReduceDim.value()) {
157 iteratorTypes.push_back(vector::IteratorType::parallel);
160 iteratorTypes.push_back(vector::IteratorType::reduction);
165 0, exprs, reduceOp.getContext());
171 return IteratorTypeAttr::get(rewriter.getContext(), t);
200 struct CombineContractABTranspose final
204 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
207 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
208 Value lhs = contractOp.getLhs();
209 Value rhs = contractOp.getRhs();
211 bool changed =
false;
212 for (
Value *operand : {&lhs, &rhs}) {
214 auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
218 transposeOp.getPermutation(), contractOp.getContext());
220 *operand = transposeOp.getVector();
226 contractOp, lhs, rhs, contractOp.getAcc(),
264 struct CombineContractResultTranspose final
270 auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
271 if (!contractOp || !contractOp->hasOneUse())
274 auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
279 auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
291 auto combinedResMap = resTMap.compose(contractMap);
298 maps.back() = combinedResMap;
301 resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
329 struct CombineContractBroadcast
333 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
336 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
337 Value lhs = contractOp.getLhs();
338 Value rhs = contractOp.getRhs();
340 bool changed =
false;
341 for (
Value *operand : {&lhs, &rhs}) {
347 auto srcType = dyn_cast<VectorType>(
broadcast.getSourceType());
349 srcType.getRank() ==
broadcast.getResultVectorType().getRank())
352 broadcast.getResultVectorType().getRank() - srcType.getRank();
353 bool innerDimBroadcast =
false;
356 if (dim.value() !=
broadcast.getResultVectorType().getDimSize(
357 rankDiff + dim.index())) {
358 innerDimBroadcast =
true;
361 originalDims.push_back(
366 if (innerDimBroadcast)
371 bool nonUnitDimReductionBroadcast =
false;
372 for (int64_t i = 0; i < rankDiff; ++i) {
373 if (
broadcast.getResultVectorType().getDimSize(i) != 1 &&
376 nonUnitDimReductionBroadcast =
true;
380 if (nonUnitDimReductionBroadcast)
386 map = broadcastMap.
compose(map);
402 for (
unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
403 if (!unusedDimsBitVector.test(i))
404 iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
411 bool hasReductionIteratorApplyingOnBothSides =
false;
412 for (
unsigned i = 0; i < iterators.size(); ++i) {
416 hasReductionIteratorApplyingOnBothSides =
true;
420 if (!hasReductionIteratorApplyingOnBothSides)
428 contractOp, lhs, rhs, contractOp.getAcc(),
447 struct ReorderCastOpsOnBroadcast
460 if (
auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
461 castResTy = vecTy.clone(castResTy);
464 bcastOp.getSource(), castResTy, op->
getAttrs());
485 struct ReorderElementwiseOpsOnTranspose final
501 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
503 transposeMaps.push_back(transposeOp.getPermutation());
504 srcType = transposeOp.getSourceVectorType();
509 if (transposeMaps.empty())
514 if (!llvm::all_equal(transposeMaps))
522 auto order = transposeMaps.front();
524 for (
int i = 0, e = order.size(); i < e; ++i)
525 invOrder[order[i]] = i;
528 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
530 srcValues.push_back(transposeOp.getVector());
534 srcType.clone(cast<VectorType>(operand.getType()).getElementType());
535 srcValues.push_back(rewriter.
create<vector::TransposeOp>(
536 operand.getLoc(), vectorType, operand, invOrder));
540 auto vectorType = srcType.clone(
547 transposeMaps.front());
554 return llvm::to_vector<4>(
555 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
556 [](IntegerAttr attr) { return attr.getInt(); }));
568 struct BubbleDownVectorBitCastForExtract
575 if (extractOp.getSourceVectorType().getRank() != 1)
578 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
582 VectorType castSrcType = castOp.getSourceVectorType();
583 VectorType castDstType = castOp.getResultVectorType();
584 assert(castSrcType.getRank() == castDstType.getRank());
589 if (castSrcType.getNumElements() == 1)
594 if (castSrcType.getNumElements() > castDstType.getNumElements())
597 unsigned expandRatio =
598 castDstType.getNumElements() / castSrcType.getNumElements();
601 assert(values[0].is<Attribute>() &&
"Unexpected non-constant index");
602 return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
610 Value packedValue = rewriter.
create<vector::ExtractOp>(
611 loc, castOp.getSource(), index / expandRatio);
614 loc, packedVecType, rewriter.
getZeroAttr(packedVecType));
615 packedValue = rewriter.
create<vector::InsertOp>(loc, packedValue, zero,
620 VectorType packedType =
623 rewriter.
create<vector::BitCastOp>(loc, packedType, packedValue);
627 index % expandRatio);
644 struct BubbleDownBitCastForStridedSliceExtract
648 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
650 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
654 VectorType castSrcType = castOp.getSourceVectorType();
655 VectorType castDstType = castOp.getResultVectorType();
656 assert(castSrcType.getRank() == castDstType.getRank());
658 int64_t castSrcLastDim = castSrcType.getShape().back();
659 int64_t castDstLastDim = castDstType.getShape().back();
661 if (castSrcLastDim > castDstLastDim)
665 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
666 [](
const APInt &val) { return !val.isOne(); }))
669 unsigned rank = extractOp.getSourceVectorType().getRank();
670 assert(castDstLastDim % castSrcLastDim == 0);
671 int64_t expandRatio = castDstLastDim / castSrcLastDim;
677 ArrayAttr newOffsets = extractOp.getOffsets();
678 if (newOffsets.size() == rank) {
680 if (offsets.back() % expandRatio != 0)
682 offsets.back() = offsets.back() / expandRatio;
687 ArrayAttr newSizes = extractOp.getSizes();
688 if (newSizes.size() == rank) {
690 if (sizes.back() % expandRatio != 0)
692 sizes.back() = sizes.back() / expandRatio;
697 llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
698 dims.back() = dims.back() / expandRatio;
699 VectorType newExtractType =
702 auto newExtractOp = rewriter.
create<vector::ExtractStridedSliceOp>(
703 extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
704 newSizes, extractOp.getStrides());
707 extractOp, extractOp.getType(), newExtractOp);
723 struct BubbleUpBitCastForInsert :
public OpRewritePattern<vector::BitCastOp> {
728 VectorType castSrcType = bitcastOp.getSourceVectorType();
729 VectorType castDstType = bitcastOp.getResultVectorType();
732 if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
733 castDstType.isScalable())
736 int64_t castSrcLastDim = castSrcType.getShape().back();
737 int64_t castDstLastDim = castDstType.getShape().back();
738 bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
740 if (isNumElemsShrink) {
741 assert(castSrcLastDim % castDstLastDim == 0);
742 ratio = castSrcLastDim / castDstLastDim;
744 assert(castDstLastDim % castSrcLastDim == 0);
745 ratio = castDstLastDim / castSrcLastDim;
748 auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
753 auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
760 isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
761 VectorType newCastSrcType =
763 auto newCastSrcOp = rewriter.
create<vector::BitCastOp>(
764 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
768 isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
769 VectorType newCastDstType =
773 auto newCastDstOp = rewriter.
create<vector::BitCastOp>(
774 bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
778 bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
794 struct BubbleUpBitCastForStridedSliceInsert
800 VectorType castSrcType = bitcastOp.getSourceVectorType();
801 VectorType castDstType = bitcastOp.getResultVectorType();
802 assert(castSrcType.getRank() == castDstType.getRank());
804 if (castSrcType.getRank() == 0)
807 int64_t castSrcLastDim = castSrcType.getShape().back();
808 int64_t castDstLastDim = castDstType.getShape().back();
810 if (castSrcLastDim < castDstLastDim)
813 assert(castSrcLastDim % castDstLastDim == 0);
814 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
817 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
822 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
823 [](
const APInt &val) { return !val.isOne(); }))
826 unsigned rank = insertOp.getSourceVectorType().getRank();
829 if (rank != insertOp.getDestVectorType().getRank())
833 unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
834 unsigned destinationWidth =
835 castDstType.getElementType().getIntOrFloatBitWidth();
836 unsigned numElements = destinationWidth / sourceWidth;
837 if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
840 ArrayAttr newOffsets = insertOp.getOffsets();
841 assert(newOffsets.size() == rank);
843 if (offsets.back() % shrinkRatio != 0)
845 offsets.back() = offsets.back() / shrinkRatio;
849 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
850 srcDims.back() = srcDims.back() / shrinkRatio;
851 VectorType newCastSrcType =
854 auto newCastSrcOp = rewriter.
create<vector::BitCastOp>(
855 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
858 llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
859 dstDims.back() = dstDims.back() / shrinkRatio;
860 VectorType newCastDstType =
863 auto newCastDstOp = rewriter.
create<vector::BitCastOp>(
864 bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
867 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
868 insertOp.getStrides());
892 struct BreakDownVectorBitCast :
public OpRewritePattern<vector::BitCastOp> {
897 std::function<
bool(vector::BitCastOp)> controlFn,
904 if (controlFn && !controlFn(bitcastOp))
907 VectorType castSrcType = bitcastOp.getSourceVectorType();
908 VectorType castDstType = bitcastOp.getResultVectorType();
909 assert(castSrcType.getRank() == castDstType.getRank());
912 if (castSrcType.getRank() != 1)
915 int64_t castSrcLastDim = castSrcType.getShape().back();
916 int64_t castDstLastDim = castDstType.getShape().back();
918 if (castSrcLastDim < castDstLastDim)
921 assert(castSrcLastDim % castDstLastDim == 0);
922 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
924 if (castSrcLastDim == shrinkRatio)
928 Type elemType = castDstType.getElementType();
933 Value res = rewriter.
create<SplatOp>(loc, castDstType, zero);
937 VectorType newCastDstType =
939 castDstType.getElementType());
941 for (
int i = 0, e = shrinkRatio; i < e; ++i) {
942 Value extracted = rewriter.
create<ExtractStridedSliceOp>(
944 sliceShape, strides);
946 rewriter.
create<BitCastOp>(loc, newCastDstType, extracted);
947 res = rewriter.
create<InsertStridedSliceOp>(
956 std::function<bool(BitCastOp)> controlFn;
973 struct ReorderElementwiseOpsOnBroadcast final
990 if (isa<vector::FMAOp>(op)) {
996 if (!lhsBcastOrSplat ||
997 !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
999 auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
1006 auto bcast = val.getDefiningOp<vector::BroadcastOp>();
1008 return (bcast.getOperand().getType() == lhsBcastOrSplatType);
1009 auto splat = val.getDefiningOp<vector::SplatOp>();
1011 return (splat.getOperand().getType() == lhsBcastOrSplatType);
1021 srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1027 lhsBcastOrSplatType, op->
getAttrs());
1032 op, vectorType, elementwiseOp->
getResults());
1048 bool force32BitVectorIndices, int64_t dim,
1057 if (dim == 0 && force32BitVectorIndices) {
1060 }
else if (dim == 0) {
1063 }
else if (force32BitVectorIndices) {
1065 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1068 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1070 Value indices = rewriter.
create<arith::ConstantOp>(loc, indicesAttr);
1075 indices = rewriter.
create<arith::AddIOp>(loc, ov, indices);
1080 rewriter.
create<vector::SplatOp>(loc, indices.
getType(), bound);
1081 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
1085 template <
typename ConcreteOp>
1088 explicit MaterializeTransferMask(
MLIRContext *context,
bool enableIndexOpt,
1091 force32BitVectorIndices(enableIndexOpt) {}
1095 if (!xferOp.hasOutOfBoundsDim())
1098 if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1102 VectorType vtp = xferOp.getVectorType();
1109 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1110 Value off = xferOp.getIndices()[lastIndex];
1114 Value mask = rewriter.
create<vector::CreateMaskOp>(
1117 vtp.getScalableDims()),
1119 if (xferOp.getMask()) {
1121 mask = rewriter.
create<arith::AndIOp>(loc, mask, xferOp.getMask());
1125 xferOp.getMaskMutable().assign(mask);
1133 const bool force32BitVectorIndices;
1137 class VectorCreateMaskOpConversion
1140 explicit VectorCreateMaskOpConversion(
MLIRContext *context,
1141 bool enableIndexOpt,
1144 force32BitVectorIndices(enableIndexOpt) {}
1148 auto dstType = op.getType();
1149 if (cast<VectorType>(dstType).isScalable())
1151 int64_t rank = dstType.getRank();
1155 op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1156 rank == 0 ? 0 : dstType.getDimSize(0),
1162 const bool force32BitVectorIndices;
1166 static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp,
bool value) {
1167 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1172 assert(denseAttr.getElementType().isInteger(1) &&
"Unexpected type");
1173 return denseAttr.isSplat() && denseAttr.getSplatValue<
bool>() == value;
1192 auto vecType = dyn_cast<VectorType>(selectOp.getType());
1193 if (!vecType || !vecType.getElementType().isInteger(1))
1197 Value cond = selectOp.getCondition();
1198 if (isa<VectorType>(cond.
getType()))
1202 if (vecType.getRank() != 1 || vecType.isScalable())
1206 if (vecType.getShape()[0] != 1)
1209 auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1210 if (!trueConst || !allI1ConstantValuesSetTo(trueConst,
true))
1214 selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1215 if (!falseConst || !allI1ConstantValuesSetTo(falseConst,
false))
1219 auto elemType = rewriter.
getIntegerType(vecType.getNumElements());
1234 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1244 int rankDiff = srcType.getRank() - vectorType.getRank();
1245 for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1248 int dim = vectorType.getRank() - i - 1;
1249 if (srcStrides[dim + rankDiff] != 1 ||
1250 srcType.getDimSize(dim + rankDiff) != 1 ||
1251 vectorType.getDimSize(dim) != 1)
1259 class DropInnerMostUnitDimsTransferRead
1263 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1266 if (readOp.getTransferRank() == 0)
1270 if (readOp.getMask())
1273 auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1277 if (!readOp.getPermutationMap().isMinorIdentity())
1280 auto targetType = readOp.getVectorType();
1281 if (targetType.getRank() <= 1)
1285 getTransferFoldableInnerUnitDims(srcType, targetType);
1286 if (
failed(maybeDimsToDrop))
1289 size_t dimsToDrop = maybeDimsToDrop.value();
1290 if (dimsToDrop == 0)
1293 auto resultTargetVecType =
1295 targetType.getElementType());
1297 auto loc = readOp.getLoc();
1304 auto resultMemrefType =
1305 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1306 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1308 ArrayAttr inBoundsAttr =
1309 readOp.getInBounds()
1311 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1313 Value rankedReducedView = rewriter.
create<memref::SubViewOp>(
1314 loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
1316 cast<ShapedType>(rankedReducedView.
getType()), resultTargetVecType);
1317 Value result = rewriter.
create<vector::TransferReadOp>(
1318 loc, resultTargetVecType, rankedReducedView,
1320 readOp.getPadding(),
1322 Value(), inBoundsAttr);
1345 class DropInnerMostUnitDimsTransferWrite
1349 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1352 if (writeOp.getTransferRank() == 0)
1356 if (writeOp.getMask())
1359 auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1363 if (!writeOp.getPermutationMap().isMinorIdentity())
1366 auto targetType = writeOp.getVectorType();
1367 if (targetType.getRank() <= 1)
1371 getTransferFoldableInnerUnitDims(srcType, targetType);
1372 if (
failed(maybeDimsToDrop))
1375 size_t dimsToDrop = maybeDimsToDrop.value();
1376 if (dimsToDrop == 0)
1379 auto resultTargetVecType =
1381 targetType.getElementType());
1390 auto resultMemrefType =
1391 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1392 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1394 ArrayAttr inBoundsAttr =
1395 writeOp.getInBounds()
1397 writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1400 Value rankedReducedView = rewriter.
create<memref::SubViewOp>(
1401 loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
1403 cast<ShapedType>(rankedReducedView.
getType()), resultTargetVecType);
1405 auto shapeCast = rewriter.
createOrFold<vector::ShapeCastOp>(
1406 loc, resultTargetVecType, writeOp.getVector());
1408 writeOp, shapeCast, rankedReducedView,
1411 Value(), inBoundsAttr);
1419 struct CanonicalizeContractMatmulToMMT final
1423 using FilterConstraintType =
1427 FilterConstraintType constraint)
1429 filter(std::move(constraint)) {}
1437 Value lhs = op.getLhs();
1438 Value rhs = op.getRhs();
1439 Value res = op.getAcc();
1443 auto infer = [&](MapList m) {
1450 static constexpr std::array<int64_t, 2> perm = {1, 0};
1451 auto iteratorTypes = op.getIteratorTypes().getValue();
1453 if (iteratorTypes.size() != 3 ||
1460 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1461 if (maps == canonicalForm)
1466 auto createTranspose = [&rewriter, loc](
Value mat) ->
Value {
1467 if (
auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1469 rewriter.
create<vector::TransposeOp>(loc, sext.getIn(), perm);
1470 VectorType newType =
1471 cast<VectorType>(trans.
getType())
1472 .
clone(cast<VectorType>(mat.getType()).getElementType());
1473 return rewriter.
create<arith::ExtSIOp>(loc, newType, trans);
1475 if (
auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1477 rewriter.
create<vector::TransposeOp>(loc,
zext.getIn(), perm);
1478 VectorType newType =
1480 cast<VectorType>(mat.getType()).getElementType());
1481 return rewriter.
create<arith::ExtUIOp>(loc, newType, trans);
1483 return rewriter.
create<vector::TransposeOp>(loc, mat, perm);
1486 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1487 rhs = createTranspose(rhs);
1488 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1489 lhs = createTranspose(lhs);
1490 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1491 rhs = createTranspose(rhs);
1492 lhs = createTranspose(lhs);
1493 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1494 std::swap(rhs, lhs);
1495 rhs = createTranspose(rhs);
1496 lhs = createTranspose(lhs);
1497 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1498 std::swap(rhs, lhs);
1499 rhs = createTranspose(rhs);
1500 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1501 std::swap(lhs, rhs);
1502 lhs = createTranspose(lhs);
1503 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1504 std::swap(lhs, rhs);
1510 op.getIteratorTypes());
1515 FilterConstraintType filter;
1535 struct FoldArithExtIntoContractionOp
1539 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1542 auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
1543 auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
1545 if (!lhsDefOp || !rhsDefOp) {
1547 "no defining op on contract operands");
1551 contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1552 contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1553 contractOp.getIteratorTypesAttr());
1575 if (op.getKind() != vector::CombiningKind::ADD)
1579 Value acc = op.getAcc();
1586 auto parentReduction = acc.
getDefiningOp<vector::ReductionOp>();
1587 if (!parentReduction)
1592 if (isa<IntegerType>(acc.
getType())) {
1594 loc, parentReduction.getVector(), op.getVector());
1596 vAdd = rewriter.
create<arith::AddFOp>(loc, parentReduction.getVector(),
1600 parentReduction.getAcc());
1634 struct DropUnitDimFromElementwiseOps final
1643 if (!resultVectorType)
1650 if (sourceVectorType.getRank() < 2)
1653 bool hasTrailingDimUnitFixed =
1654 ((sourceVectorType.getShape().back() == 1) &&
1655 (!sourceVectorType.getScalableDims().back()));
1656 bool hasLeadingDimUnitFixed =
1657 ((sourceVectorType.getShape().front() == 1) &&
1658 (!sourceVectorType.getScalableDims().front()));
1659 if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
1664 int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
1669 auto opVectorType = cast<VectorType>(operand.getType());
1671 auto opSC = rewriter.
create<vector::ShapeCastOp>(loc, newVType, operand);
1672 newOperands.push_back(opSC);
1675 VectorType newResultVectorType =
1680 newResultVectorType, op->
getAttrs());
1710 if (op.getKind() != vector::CombiningKind::ADD)
1713 Type elemType = op.getSourceVectorType().getElementType();
1716 if (!isa<FloatType>(elemType))
1719 auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
1729 auto newAdd = rewriter.
create<arith::AddFOp>(vAdd.
getLoc(), addLhs.getLhs(),
1747 struct BreakDownVectorReduction final :
OpRewritePattern<vector::ReductionOp> {
1749 unsigned maxNumElementsToExtract,
1752 maxNumElementsToExtract(maxNumElementsToExtract) {}
1756 VectorType type = op.getSourceVectorType();
1757 if (type.isScalable() || op.isMasked())
1759 assert(type.getRank() == 1 &&
"Expected a 1-d vector");
1761 int64_t numElems = type.getNumElements();
1762 if (numElems > maxNumElementsToExtract) {
1764 op, llvm::formatv(
"has too many vector elements ({0}) to break down "
1765 "(max allowed: {1})",
1766 numElems, maxNumElementsToExtract));
1772 extractedElem = rewriter.
create<vector::ExtractOp>(
1773 loc, op.getVector(),
static_cast<int64_t
>(idx));
1775 Value res = extracted.front();
1776 for (
auto extractedElem : llvm::drop_begin(extracted))
1778 extractedElem, op.getFastmathAttr());
1779 if (
Value acc = op.getAcc())
1781 op.getFastmathAttr());
1788 unsigned maxNumElementsToExtract = 0;
1795 patterns.
add<FoldArithExtIntoContractionOp>(patterns.
getContext());
1801 patterns.
add<VectorCreateMaskOpConversion,
1802 MaterializeTransferMask<vector::TransferReadOp>,
1803 MaterializeTransferMask<vector::TransferWriteOp>>(
1804 patterns.
getContext(), force32BitVectorIndices, benefit);
1810 patterns.
add<ShapeCastOpFolder>(patterns.
getContext(), benefit);
1815 patterns.
add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1821 patterns.
add<BubbleDownVectorBitCastForExtract,
1822 BubbleDownBitCastForStridedSliceExtract,
1823 BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
1829 std::function<
bool(vector::BitCastOp)> controlFn,
PatternBenefit benefit) {
1831 std::move(controlFn), benefit);
1836 std::function<
LogicalResult(vector::ContractionOp)> constraint,
1838 patterns.
add<CanonicalizeContractMatmulToMMT>(patterns.
getContext(), benefit,
1839 std::move(constraint));
1844 patterns.
add<MultiReduceToContract, CombineContractBroadcast,
1845 CombineContractABTranspose, CombineContractResultTranspose,
1846 ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
1853 patterns.
add<DropInnerMostUnitDimsTransferRead,
1854 DropInnerMostUnitDimsTransferWrite>(patterns.
getContext(),
1860 patterns.
add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
1866 patterns.
add<ChainedReduction>(patterns.
getContext(), benefit);
1874 patterns.
add<BreakDownVectorReduction>(patterns.
getContext(),
1875 maxNumElementsToExtract, benefit);
1882 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
static uint64_t zext(uint32_t arg)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static uint64_t getFirstIntValue(ValueRange values)
Returns the integer value from the first valid input element, assuming Value inputs are defined by a ...
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
type_range getType() const
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful match possible).
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isn't supposed to match, then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
type_range getType() const
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that use vector.shape_cast to help fold unit dims.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
void populateBreakDownVectorBitCastOpPatterns(RewritePatternSet &patterns, std::function< bool(BitCastOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to break down 1-D vector.bitcast ops based on the destination vector shape.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that fold chained vector reductions.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that bubble up/down bitcast ops.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for the backends with native support.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction with MMT semantics (matrix matrix multiplication with the RHS transposed).
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit=1)
These patterns materialize masks for various vector ops such as transfers.
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant vector broadcasts.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to reduce the rank of the operands of vector transfer ops to operate on the largest contiguous vector.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the contract.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
void populateBreakDownVectorReductionPatterns(RewritePatternSet &patterns, unsigned maxNumElementsToExtract=2, PatternBenefit benefit=1)
Patterns to break down vector reductions into a series of arith reductions over vector elements.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...