#include <type_traits>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"
// Helper to pull the integer values out of an ArrayAttr of IntegerAttrs.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
// ShapeCastOpFolder folds cancelling shape cast ops away.
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
    auto resultVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if the source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        cast<VectorType>(sourceShapeCastOp.getSource().getType());
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if the two shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};
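// Example of the fold above (shapes are illustrative, not from this file):
// the two casts cancel, so all uses of %1 are replaced with %arg0 directly.
//
//   %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
//   %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>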
// Convert a vector.multi_reduction <add> fed by a multiplication into a
// vector.contract. (Fragment: parts of the rewrite are elided; see the
// sketch below.)
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    auto reductionMask = reduceOp.getReductionMask();
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symbolCount=*/0, exprs,
                                 reduceOp.getContext());
    // ... replace with a vector.contract on mulOp's operands, mapping each
    //     iterator type t to an attribute via:
    //     return IteratorTypeAttr::get(rewriter.getContext(), t);
  }
};
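// Sketch of the rewrite above (shapes illustrative, not from this file):
//
//   %m = arith.mulf %a, %b : vector<8x16xf32>
//   %r = vector.multi_reduction <add>, %m, %acc [1]
//          : vector<8x16xf32> to vector<8xf32>
//
// becomes a vector.contract with identity input maps, a (d0, d1) -> (d0)
// result map, and iterator types ["parallel", "reduction"].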
// Fold vector.transpose ops feeding a vector.contract's lhs/rhs into the
// contraction's indexing maps.
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
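// Sketch of the rewrite above: a vector.transpose feeding an operand is
// absorbed by composing its inverse permutation into that operand's indexing
// map. E.g. transpose(%lhs, [1, 0]) disappears and the lhs map
// (d0, d1, d2) -> (d0, d2) becomes (d0, d1, d2) -> (d2, d0).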
// Merge a transpose of a contraction's result with a transposed accumulator
// into the contraction's result indexing map. (Fragment; permutation-map
// bookkeeping elided.)
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    // ... derive resTMap/contractMap from the two permutations, then:
    auto combinedResMap = resTMap.compose(contractMap);
    maps.back() = combinedResMap;
    // ... rebuild the contraction on the untransposed accumulator:
    //     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
    //         resTOp, contractOp.getLhs(), contractOp.getRhs(),
    //         accTOp.getVector(), ...);
  }
};
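// Sketch: for %t = transpose(contract(%a, %b, transpose(%acc))), the two
// transposes are folded by composing the result permutation into the
// contraction's acc/result indexing map, so the contraction directly
// produces the transposed layout and consumes %acc untransposed.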
// Fold a vector.broadcast feeding a vector.contract operand into the
// contraction's indexing maps. (Fragment: parts of the per-operand map
// update and the both-sides check are elided.)
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // vector.contract can only take vectors as operands.
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
      if (!srcType ||
          srcType.getRank() == broadcast.getResultVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // vector.contract does not support inner-dimension broadcast.
      if (innerDimBroadcast)
        continue;

      // Folding a broadcast onto a reduction dimension of non-unit size
      // would be incorrect.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                         originalDims, contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }
    if (!changed)
      return failure();

    // Compress the dims that became unused after composing the maps, and
    // keep only the corresponding iterators.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // Bail out if the compression removed all reduction dimension pairs,
    // e.g. when the only reduction iterator was a broadcast-created unit dim.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      // ... if dim i is used by both the lhs and rhs maps:
      hasReductionIteratorApplyingOnBothSides = true;
      break;
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};
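// Sketch of the rewrite above (illustrative): a broadcast that only adds
// leading unit/parallel dims is folded into the indexing maps.
//
//   %b = vector.broadcast %arg : vector<32x16xf32> to vector<8x32x16xf32>
//   %r = vector.contract ... %b, ...
//
// becomes a contract that reads %arg directly through a map that skips the
// broadcast dimension, which is then compressed away if unused.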
// Reorder casts after broadcasts: cast the (smaller) broadcast source rather
// than the broadcast result.
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
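// Sketch: cast(broadcast(x)) becomes broadcast(cast(x)), so the cast runs on
// the smaller pre-broadcast value, e.g.
//
//   %b = vector.broadcast %s : f16 to vector<4xf16>
//   %e = arith.extf %b : vector<4xf16> to vector<4xf32>
//
// -> %e = arith.extf %s to f32, followed by one broadcast to vector<4xf32>.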
// Move an elementwise op below its transposed operands: apply it to the
// un-transposed sources and transpose the result once. (Fragment; the final
// rebuild is elided.)
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Collect the permutations of all transposed operands.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // All transposed operands must agree on the permutation.
    if (!llvm::all_equal(transposeMaps))
      return failure();

    // Invert the permutation so non-transposed operands can be
    // pre-transposed to match.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;
    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        auto vectorType = srcType.clone(
            cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }
    // ... rebuild the op on srcValues and transpose its result by
    //     transposeMaps.front().
  }
};
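// Sketch: for an elementwise op whose vector operands are all transposed
// with the same permutation, the op is applied to the un-transposed sources
// and a single transpose is emitted for the result; any operand that is not
// transposed is first pre-transposed with the inverse permutation.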
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
// Bubble a vector.bitcast below vector.extract: extract the packing
// subvector from the bitcast source, bitcast it, then extract the scalar.
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Bubbling buys nothing if the source has a single element.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only accept bitcasts that expand the number of elements.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    SmallVector<OpFoldResult> values = extractOp.getMixedPosition();
    assert(values[0].is<Attribute>() && "Unexpected non-constant index");
    uint64_t index = cast<IntegerAttr>(values[0].get<Attribute>()).getInt();

    // Extract the subvector that packs the requested scalar.
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);

    // Pad it into a single-element vector so it can be bitcast.
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);
    // Cast to the destination element type and extract the desired scalar.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value casted =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, casted,
                                                   index % expandRatio);
    return success();
  }
};
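// Sketch of the rewrite above (illustrative):
//
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[5] : f16 from vector<8xf16>
//
// becomes: extract element 5/2 = 2 of %src, bitcast that packed f32 to
// vector<2xf16>, then extract element 5%2 = 1 from it.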
// Bubble a vector.bitcast below vector.extract_strided_slice by scaling the
// slice's trailing offset/size down by the expansion ratio.
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be
    // implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If fewer offsets than the rank are given, the last bitcast dimension is
    // implicitly taken in full; otherwise scale down the last offset.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
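// Sketch (illustrative):
//
//   %cast = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast
//          {offsets = [4], sizes = [4], strides = [1]}
//          : vector<8xf16> to vector<4xf16>
//
// becomes an extract_strided_slice of %src with offsets = [2], sizes = [2],
// followed by a bitcast of vector<2xf32> to vector<4xf16>.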
// Bubble a vector.bitcast above vector.insert_strided_slice by bitcasting
// the inserted slice and the destination separately.
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from insert_strided_slice.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert's source and destination to have the same rank;
    // other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // The insert's source shape must itself be castable to the destination
    // element type.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
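// Sketch (illustrative): for
//
//   %0 = vector.insert_strided_slice %slice, %dst
//          {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
//
// the bitcast is bubbled up: %slice and %dst are bitcast to vector<2xf32>
// and vector<4xf32> respectively, and re-inserted at offset 4/2 = 2.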
// Break down a 1-D vector.bitcast that packs elements into per-group
// extract/bitcast/insert ops, under an optional control function.
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Only support the rank-1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};
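// Sketch (illustrative, shrinkRatio = 2):
//
//   %0 = vector.bitcast %src : vector<8xf16> to vector<4xf32>
//
// becomes two extract_strided_slice ops of vector<4xf16>, each bitcast to
// vector<2xf32> and inserted into an initially-zero vector<4xf32> result.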
// Sink an elementwise op below its broadcast/splat operands: perform the op
// on the pre-broadcast values and broadcast the result once. (Fragment; the
// operand checks and rebuild are elided — see the sketch below.)
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // vector.fma is not handled here.
    if (isa<vector::FMAOp>(op)) {
      return failure();
    }
    // The first operand must come from a broadcast or splat.
    Operation *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
    // All remaining operands must broadcast/splat from that same source type:
    //   auto bcast = val.getDefiningOp<vector::BroadcastOp>();
    //   ... return (bcast.getOperand().getType() == lhsBcastOrSplatType);
    //   auto splat = val.getDefiningOp<vector::SplatOp>();
    //   ... return (splat.getOperand().getType() == lhsBcastOrSplatType);
    // Collect the pre-broadcast operands, rebuild the op on them at
    // lhsBcastOrSplatType, and broadcast its result:
    //   srcValues.push_back(operand.getDefiningOp()->getOperand(0));
    //   ... rewriter.create(..., lhsBcastOrSplatType, op->getAttrs());
  }
};
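// Sketch: broadcasts are sunk below the elementwise op,
//
//   %a = vector.broadcast %x : f32 to vector<4xf32>
//   %b = vector.broadcast %y : f32 to vector<4xf32>
//   %r = arith.addf %a, %b : vector<4xf32>
//
// becoming a scalar arith.addf on %x, %y followed by one vector.broadcast.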
// Builds a constant vector of sequential indices and compares it (signed
// less-than) against the bound `b`, optionally shifted by `off`, yielding an
// i1 vector mask.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // 32-bit indices allow a higher degree of SIMD parallelism.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}
// Materialize an explicit mask for 1-D transfer ops that may read/write out
// of bounds, so the op itself can then be marked in-bounds.
template <typename ConcreteOp>
class MaterializeTransferMask final : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();
    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask: elements in [0 .. dim - offset) are set,
    // elements in [dim - offset .. vector_length) are unset.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect with the mask already provided on the op.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
// Lower vector.create_mask (rank <= 1, non-scalable) to an explicit vector
// comparison against the mask bound.
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
// Returns true if `constantOp` holds a dense i1 splat equal to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constants.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
// Fold an arith.select between an all-true and an all-false i1 vector
// constant into a broadcast of the scalar condition. (Fragment: the pattern
// declaration and the final rewrite are elided.)
  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();
    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();
    // TODO: Support n-D and scalable vectors.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();
    // TODO: Support vectors with multiple elements.
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace the select with its condition broadcast to a single-element
    // vector.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    // ...
  }
};
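// Sketch: arith.select %cond, dense<true>, dense<false> : vector<1xi1> is
// semantically just %cond, so it reduces to broadcasting the scalar
// condition into a vector<1xi1>.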
// Drop inner-most contiguous unit dims from a transfer_read's memref source
// and vector type, reading through a rank-reduced memref.subview instead.
class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (!srcType || !srcType.hasStaticShape())
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    SmallVector<int64_t> srcStrides;
    int64_t srcOffset;
    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
      return failure();

    // Count the trailing dims that are unit-size in both the source and the
    // vector, and contiguous (stride 1) in the source.
    size_t dimsToDrop = 0;
    int rankDiff = srcType.getRank() - targetType.getRank();
    for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
      int dim = targetType.getRank() - i - 1;
      if (srcStrides[dim + rankDiff] == 1 &&
          srcType.getDimSize(dim + rankDiff) == 1 &&
          targetType.getDimSize(dim) == 1) {
        dimsToDrop++;
      } else {
        break;
      }
    }
    if (dimsToDrop == 0)
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType());

    // Compute the rank-reduced memref type, updating the layout accordingly.
    MemRefType resultMemrefType;
    MemRefLayoutAttrInterface layout = srcType.getLayout();
    if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          nullptr, srcType.getMemorySpace());
    } else {
      MemRefLayoutAttrInterface updatedLayout;
      if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
        auto strides =
            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
                                               strided.getOffset(), strides);
      } else {
        AffineMap map = srcType.getLayout().getAffineMap();
        for (size_t i = 0; i < dimsToDrop; ++i) {
          int dim = srcType.getRank() - i - 1;
          // ... substitute the dropped dim with the constant 0 in `map` and
          //     build updatedLayout from the reduced map ...
        }
      }
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          updatedLayout, srcType.getMemorySpace());
    }

    auto loc = readOp.getLoc();
    SmallVector<int64_t> offsets(srcType.getRank(), 0);
    SmallVector<int64_t> strides(srcType.getRank(), 1);

    ArrayAttr inBoundsAttr =
        readOp.getInBounds()
            ? rewriter.getArrayAttr(
                  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
        strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
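// Sketch (illustrative): a transfer_read of vector<16x16x1xf32> from
// memref<8x16x16x1xf32> becomes a memref.subview dropping the trailing unit
// dim, a transfer_read of vector<16x16xf32>, and a vector.shape_cast back to
// vector<16x16x1xf32>.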
// Canonicalize a vector.contract with row-major matmul semantics into the
// "MMT" form: lhs {m, k}, rhs {n, k}, result {m, n}, transposing and/or
// swapping operands as needed.
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the contraction maps in (m, n, k) form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a transpose, looking through integer extensions so the cheaper
    // transpose happens on the narrower type.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
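// Sketch: a row-major matmul contract (lhs {m, k}, rhs {k, n}, acc {m, n})
// reaches the canonical MMT form by transposing the rhs:
//
//   %rhsT = vector.transpose %rhs, [1, 0]
//             : vector<16x8xf32> to vector<8x16xf32>
//
// after which the maps are {m, k}, {n, k}, {m, n}.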
// Fold arith.extf on both operands of a vector.contract into the
// contraction, which supports mixed-precision operands.
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
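// Sketch:
//
//   %lhs = arith.extf %a : vector<4x2xf16> to vector<4x2xf32>
//   %rhs = arith.extf %b : vector<2x8xf16> to vector<2x8xf32>
//   %r = vector.contract ... %lhs, %rhs, %acc ...
//
// becomes a vector.contract directly on the f16 operands (keeping the f32
// accumulator), letting the mixed-precision contraction do the extension.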
// Fold two chained vector.reduction <add> ops into one vector addition
// followed by a single reduction.
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // The accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();
    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};
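// Sketch: two chained reductions
//
//   %a = vector.reduction <add>, %x, %acc0 : vector<8xf32> into f32
//   %b = vector.reduction <add>, %y, %a : vector<8xf32> into f32
//
// fold to one vector addition plus a single reduction:
//
//   %s = arith.addf %x, %y : vector<8xf32>
//   %b = vector.reduction <add>, %s, %acc0 : vector<8xf32> into f32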
// Drop a redundant zero added into the operand of a floating-point
// vector.reduction <add>. (Fragment: the pattern declaration and the
// zero-operand match are elided.)
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by `arith.addi` folders; only check
    // for floats here.
    if (!isa<FloatType>(elemType))
      return failure();

    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    // ... require vAdd's lhs to be `addLhs = x + (+/-0.0)`, then rebuild the
    //     addition without the zero:
    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
                                                 vAdd.getRhs());
    // ... replace the reduction so it consumes newAdd.
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
}

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
}

void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
                                                     benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose,
               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorBroadcastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
}
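// A minimal usage sketch (assumed, not part of this file): collecting some
// of the populate entry points above into one pattern set and applying it
// with the greedy driver. The function name is illustrative.
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   void runVectorCleanups(Operation *root) {
//     RewritePatternSet patterns(root->getContext());
//     vector::populateShapeCastFoldingPatterns(patterns);
//     vector::populateVectorReductionToContractPatterns(patterns);
//     vector::populateVectorMaskMaterializationPatterns(
//         patterns, /*force32BitVectorIndices=*/true);
//     (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
//   }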
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"