#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;
/// Helper to cast the integer elements of `arrayAttr` to `IntType`.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {
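/// Convert a vector.multi_reduction <add> whose source is an elementwise
/// multiply into a vector.contract. The IR below is only illustrative; the
/// exact shapes, maps, and combining attributes depend on the input:
///
///   %m = arith.mulf %a, %b : vector<8x32x16xf32>
///   %r = vector.multi_reduction <add>, %m, %acc [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
///  =>
///   %r = vector.contract {indexing_maps = [...],
///            iterator_types = ["parallel", "reduction", "parallel"],
///            kind = #vector.kind<add>} %a, %b, %acc
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>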
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap =
        AffineMap::get(/*dimCount=*/reductionMask.size(), /*symbolCount=*/0,
                       exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1),
        reduceOp.getAcc(), rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            }))));
    return success();
  }
};
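/// Merge a vector.transpose of the LHS and/or RHS (A/B) operand into its
/// vector.contract user by composing the transpose permutation into the
/// corresponding indexing map. Sketch (illustrative):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///   %1 = vector.contract %0, %arg1, %cst {...}
///  =>
///   %1 = vector.contract %arg0, %arg1, %cst {indexing_maps = [permuted, ..]}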
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
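/// Merge a transpose of the accumulator and of the result into the
/// vector.contract itself by folding the permutation into the ACC/result
/// indexing map. Only fires when the contract has a single use and the
/// accumulator and result are transposed in mutually inverse ways.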
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    MLIRContext *context = contractOp.getContext();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    AffineMap contractMap = maps.back();

    auto accTMap =
        AffineMap::getPermutationMap(accTOp.getPermutation(), context);
    auto resTMap =
        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
    auto combinedResMap = resTMap.compose(contractMap);

    // The accumulator and the result share the same indexing map, so the two
    // transposes can only cancel out if they are inverses of each other.
    if (inversePermutation(accTMap) != resTMap)
      return failure();
    maps.back() = combinedResMap;

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
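/// Fold broadcasts of the LHS/RHS operands into a vector.contract by
/// composing the broadcast map into the operand's indexing map and dropping
/// the iteration dims that become unused. Only leading, non-reduction (or
/// unit-size) broadcast dims can be folded. When the contract is masked, the
/// mask is shape_cast to the reduced shape; the dropped dims must be unit.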
static FailureOr<Value>
combineContractAndBroadcast(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) {
  SmallVector<AffineMap> maps =
      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
  Value lhs = contractOp.getLhs();
  Value rhs = contractOp.getRhs();
  size_t index = 0;
  bool changed = false;
  for (Value *operand : {&lhs, &rhs}) {
    AffineMap &map = maps[index++];
    auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      continue;
    // The contraction op can only take vectors as operands.
    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
    if (!srcType ||
        srcType.getRank() == broadcast.getResultVectorType().getRank())
      continue;
    int64_t rankDiff =
        broadcast.getResultVectorType().getRank() - srcType.getRank();
    bool innerDimBroadcast = false;
    SmallVector<AffineExpr> originalDims;
    for (const auto &dim : llvm::enumerate(srcType.getShape())) {
      if (dim.value() !=
          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
        innerDimBroadcast = true;
        break;
      }
      originalDims.push_back(
          rewriter.getAffineDimExpr(dim.index() + rankDiff));
    }
    // Contract does not support inner-dimension broadcasts; only leading dims
    // of the operand may be broadcast.
    if (innerDimBroadcast)
      continue;

    // It would be incorrect to fold a broadcast onto a reduction dimension
    // of non-unit size.
    bool nonUnitDimReductionBroadcast = false;
    for (int64_t i = 0; i < rankDiff; ++i) {
      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
          isReductionIterator(contractOp.getIteratorTypes()
                                  .getValue()[map.getDimPosition(i)])) {
        nonUnitDimReductionBroadcast = true;
        break;
      }
    }
    if (nonUnitDimReductionBroadcast)
      continue;

    AffineMap broadcastMap =
        AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                       originalDims, contractOp.getContext());
    map = broadcastMap.compose(map);
    *operand = broadcast.getSource();
    changed = true;
  }

  if (!changed)
    return failure();

  // Determine which dims are unused now that the maps have been composed
  // with the broadcast maps, and compress them away.
  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
  for (AffineMap &m : maps)
    m = compressDims(m, unusedDimsBitVector);
  SmallVector<Attribute> iterators;
  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
    if (!unusedDimsBitVector.test(i))
      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
  }

  // If a mask is present, check whether any of the unused dims is non-unit;
  // such a dim cannot be dropped from the mask.
  VectorType oldMaskType;
  bool isAnyUnusedDimNonUnit = false;
  if (maskingOp) {
    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
        isAnyUnusedDimNonUnit = true;
        break;
      }
    }
  }

  // Check that compressing the unused dims does not remove all reduction
  // dimension pairs: the resulting contract must keep at least one reduction
  // dim used by both sides.
  bool hasReductionIteratorApplyingOnBothSides = false;
  for (unsigned i = 0; i < iterators.size(); ++i) {
    if (!isReductionIterator(iterators[i]))
      continue;
    if (maps[0].isFunctionOfDim(i) && maps[1].isFunctionOfDim(i)) {
      hasReductionIteratorApplyingOnBothSides = true;
      break;
    }
  }
  if (!hasReductionIteratorApplyingOnBothSides)
    return failure();

  Operation *newOp = vector::ContractionOp::create(
      rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
      rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

  // Handle the mask.
  if (maskingOp) {
    if (isAnyUnusedDimNonUnit)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Cannot drop non-unit mask dim.");
    assert(unusedDimsBitVector.size() ==
               static_cast<size_t>(oldMaskType.getRank()) &&
           "The mask rank is incorrect!");

    // If a dimension has been dropped, update the mask accordingly. The
    // unused (unit) dims are dropped from the front of the mask shape.
    Value mask = maskingOp.getMask();
    if (unusedDimsBitVector.count() != 0) {
      auto newShape =
          oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
      auto newShapeScalableDims =
          oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
      VectorType maskOpType =
          VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
      mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
                                         maskOpType, maskingOp.getMask())
                 .getResult();
    }

    newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
  }
  return newOp->getResult(0);
}
struct CombineContractBroadcastMask
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
  }
};
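/// Reorder casting ops so they apply to the narrower, pre-broadcast value:
/// broadcast(cast(x)) instead of cast(broadcast(x)). Sketch (illustrative):
///
///   %b = vector.broadcast %x : vector<4xi8> to vector<2x4xi8>
///   %c = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
///  =>
///   %c = arith.extsi %x : vector<4xi8> to vector<4xi32>
///   %b = vector.broadcast %c : vector<4xi32> to vector<2x4xi32>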
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
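/// Reorder an elementwise op and its vector.transpose operands so that the
/// elementwise op is computed on the untransposed values and a single
/// transpose is applied to the result. All transposed operands must use the
/// same permutation; constant operands get an explicit inverse transpose.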
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // All operands must be transpose or constant ops; collect the maps.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // All operands must be transposed in the same way.
    if (!llvm::all_equal(transposeMaps))
      return failure();

    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());

    // Constant operands need an inverse transpose; compute the inverse order.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant; create an inverse transpose op for it.
        auto vectorType = srcType.clone(
            cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(vector::TransposeOp::create(
            rewriter, operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};
/// Returns the values of `arrayAttr` as a vector of int64_t.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
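/// Bubble a vector.bitcast down through vector.extract: extract the single
/// packed element from the bitcast source, bitcast that element, then extract
/// the desired lane. Sketch (illustrative):
///
///   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
///   %1 = vector.extract %0[3] : f16 from vector<8xf16>
///  =>
///   extract %src[1] (the f32 packing lanes 2-3), bitcast it to vector<2xf16>,
///   then extract lane 1 from the result.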
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars from rank-1 vectors for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail if there is only one element in the cast op source; this avoids an
    // infinite loop given that this pattern can generate such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now, e.g.
    // vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // Only support static positions.
    auto mixedPos = extractOp.getMixedPosition();
    if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar, e.g. extract vector<1xf32> from vector<4xf32>.
    Location loc = extractOp.getLoc();
    Value packedValue = vector::ExtractOp::create(
        rewriter, loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
                                           rewriter.getZeroAttr(packedVecType));
    packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
                                           /*position=*/0);

    // Cast it to a vector with the desired scalar's type, e.g.
    // f32 -> vector<2xf16>.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        vector::BitCastOp::create(rewriter, loc, packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
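/// Bubble a vector.bitcast down through vector.extract_strided_slice by
/// extracting a proportionally scaled slice from the bitcast source first and
/// bitcasting the smaller slice afterwards. The last-dim offset and size must
/// be divisible by the element expansion ratio, and all strides must be 1.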
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If there are fewer offsets than the rank, the full range is implicitly
    // selected for the trailing dimensions. Otherwise, scale down the last
    // dimension's offset given we are extracting from fewer elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
        newOffsets, newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
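/// Bubble a vector.bitcast up through vector.insert: bitcast the inserted
/// value and the destination separately (scaling their last dim by the
/// element ratio), then re-create the insert on the bitcast values.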
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    // Bitcast the destination.
    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
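/// Bubble a vector.bitcast up through vector.insert_strided_slice, scaling
/// the last-dim offsets and shapes by the shrink ratio. Requires unit strides
/// and a last-dim offset divisible by the ratio.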
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert op to have the same rank for the source and
    // destination vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Requires that the shape of the insert op source is castable to dstType.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
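/// Break down a rank-1 vector.bitcast into a chain of smaller bitcasts, each
/// operating on a strided slice of `castDstLastDim` elements; the pieces are
/// re-assembled with insert_strided_slice. A `controlFn`, when provided,
/// selects which bitcast ops to break down.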
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on scalar vector extraction and insertion,
    // which do not support scalable vectors.
    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support rank-1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();

    Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                           rewriter.getZeroAttr(elemType));
    Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = ExtractStridedSliceOp::create(
          rewriter, loc, bitcastOp.getSource(),
          ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
      Value bitcast =
          BitCastOp::create(rewriter, loc, newCastDstType, extracted);
      res = InsertStridedSliceOp::create(
          rewriter, loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(vector::BitCastOp)> controlFn;
};
/// Returns true if both types are vectors with the same shape and matching
/// scalable dims, or if both are non-vector types.
static bool haveSameShapeAndScaling(Type t, Type u) {
  auto tVec = dyn_cast<VectorType>(t);
  auto uVec = dyn_cast<VectorType>(u);
  if (!tVec)
    return !uVec;
  if (!uVec)
    return false;
  return tVec.getShape() == uVec.getShape() &&
         tVec.getScalableDims() == uVec.getScalableDims();
}
/// If `type` is shaped, clone it with `newElementType`. Otherwise, return
/// `newElementType` itself.
static Type cloneOrReplace(Type type, Type newElementType) {
  if (auto shapedType = dyn_cast<ShapedType>(type)) {
    return shapedType.clone(newElementType);
  }
  return newElementType;
}
/// If `value` is the result of a broadcast, return the broadcast source.
/// Otherwise, return a null value.
static Value getBroadcastLikeSource(Value value) {
  Operation *op = value.getDefiningOp();
  if (!op)
    return {};
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcast.getSource();
  return {};
}
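/// Reorder an elementwise op and its broadcast operands so the elementwise op
/// is computed on the narrow, pre-broadcast values and the result is then
/// broadcast. All operands must be broadcasts from identically-shaped sources
/// or splat constants (which are re-materialized at the narrow type).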
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumOperands() == 0)
      return failure();
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    Type resultElemType = resultType.getElementType();

    // Get the broadcast source of the first non-constant operand.
    Value broadcastSource;
    for (Value operand : op->getOperands()) {
      Operation *definingOp = operand.getDefiningOp();
      if (!definingOp)
        return failure();
      if (definingOp->hasTrait<OpTrait::ConstantLike>())
        continue;
      broadcastSource = getBroadcastLikeSource(operand);
      break;
    }
    if (!broadcastSource)
      return failure();
    Type unbroadcastResultType =
        cloneOrReplace(broadcastSource.getType(), resultElemType);

    // All operands must be broadcasts from identically-shaped values or splat
    // constants; otherwise the re-ordering wouldn't be safe.
    if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
          if (auto source = getBroadcastLikeSource(val))
            return haveSameShapeAndScaling(source.getType(),
                                           broadcastSource.getType());
          SplatElementsAttr splatConst;
          return matchPattern(val, m_Constant(&splatConst));
        })) {
      return rewriter.notifyMatchFailure(
          op,
          "not all operands are constants or broadcasts from the same type");
    }

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      SplatElementsAttr splatConst;
      if (matchPattern(operand, m_Constant(&splatConst))) {
        // Re-materialize the constant at the unbroadcast type.
        Type elementType = splatConst.getElementType();
        Type newType = cloneOrReplace(unbroadcastResultType, elementType);
        Attribute newConst;
        if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
          newConst = splatConst.resizeSplat(newTypeShaped);
        } else {
          newConst = splatConst.getSplatValue<Attribute>();
        }
        Operation *newConstOp =
            operand.getDefiningOp()->getDialect()->materializeConstant(
                rewriter, newConst, newType, operand.getLoc());
        srcValues.push_back(newConstOp->getResult(0));
      } else {
        srcValues.push_back(operand.getDefiningOp()->getOperand(0));
      }
    }

    // Compute the elementwise op on the narrow values, then broadcast.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        unbroadcastResultType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, resultType, elementwiseOp->getResults());
    return success();
  }
};
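/// Sink a vector.extract through an elementwise op so that the scalar (or
/// smaller-vector) computation is done instead of the full-vector one. Sketch
/// (illustrative):
///
///   %a = arith.addf %x, %y : vector<4xf32>
///   %e = vector.extract %a[1] : f32 from vector<4xf32>
///  =>
///   %x1 = vector.extract %x[1] : f32 from vector<4xf32>
///   %y1 = vector.extract %y[1] : f32 from vector<4xf32>
///   %e = arith.addf %x1, %y1 : f32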
class ExtractOpFromElementwise final
    : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *eltwise = op.getSource().getDefiningOp();

    // vector.fma doesn't have a scalar form.
    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
        isa<vector::FMAOp>(eltwise))
      return rewriter.notifyMatchFailure(op, "not an elementwise op");

    if (eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
      return rewriter.notifyMatchFailure(op,
                                         "expected single-result single-use op");

    // A dynamic position can cause dominance issues, so conservatively fail.
    if (!op.getDynamicPosition().empty())
      return rewriter.notifyMatchFailure(
          op, "dynamic position not yet implemented");

    Type dstType = op.getType();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(eltwise);

    IRMapping mapping;
    Location loc = eltwise->getLoc();
    SmallVector<OpFoldResult> pos = op.getMixedPosition();
    for (Value arg : eltwise->getOperands()) {
      Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
      mapping.map(arg, newArg);
    }

    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
    newEltwise->getResult(0).setType(dstType);

    rewriter.replaceOp(op, newEltwise->getResults());
    rewriter.eraseOp(eltwise);
    return success();
  }
};
/// Check if the element type is suitable for vector.load/store sinking:
/// index, or a byte-aligned integer or floating-point type.
static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;
  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}
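/// Rewrite extract(vector.load) into a narrower memref.load / vector.load at
/// adjusted indices, so only the needed element(s) are loaded. The load must
/// have a single use and a fixed-size, byte-aligned element type.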
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return rewriter.notifyMatchFailure(op, "expected a load op");

    // Checking for single use so we can remove the load.
    if (!loadOp->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();

    // Non-byte-aligned types may require special handling; ignore them for
    // now.
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return rewriter.notifyMatchFailure(op, "unsupported element type");

    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    if (rankOffset < 0)
      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();

    SmallVector<Value> indices(loadOp.getIndices().begin(),
                               loadOp.getIndices().end());
    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();

    // There may be memory stores between the load and the extract op, so the
    // new load is inserted at the same place as the original one.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(loadOp);
    Location loc = loadOp.getLoc();
    ArithIndexingBuilder idxBuilderf(rewriter, loc);
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      OpFoldResult pos = extractPos[i - rankOffset];
      if (isZeroInteger(pos))
        continue;
      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    Value base = loadOp.getBase();
    if (extractVecType) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
                                                  indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
    }
    // We checked for a single use, so the load can be safely erased.
    rewriter.eraseOp(loadOp);
    return success();
  }
};
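/// Rewrite a vector.store of a broadcast-like 1-element vector into a direct
/// store of the broadcast source: memref.store for scalar sources,
/// vector.store for vector sources.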
class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Value toStore = op.getValueToStore();
    Value source = getBroadcastLikeSource(toStore);
    if (!source)
      return rewriter.notifyMatchFailure(
          op, "value to store is not from a broadcast");

    Value base = op.getBase();
    ValueRange indices = op.getIndices();
    if (isa<VectorType>(source.getType())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
    }
    return success();
  }
};
/// Build a vector comparison `indices + off < b`, materializing the constant
/// index vector at 32 or 64 bits depending on `force32BitVectorIndices`.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If all indices are assumed to fit in 32 bits, perform the comparison in
  // 32-bit for more SIMD parallelism; otherwise use 64-bit indices.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
    indices = arith::AddIOp::create(rewriter, loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
                               indices, bounds);
}
/// Attach an in-bounds mask to a 1-D vector transfer op with an out-of-bounds
/// dimension, so that the transfer no longer needs out-of-bounds handling.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
    Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
    Value mask = vector::CreateMaskOp::create(
        rewriter, loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds mask with the mask specified on the op.
      mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  const bool force32BitVectorIndices;
};
/// Conversion pattern for a 0-D/1-D vector.create_mask to a vector comparison.
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constants.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
/// Folds an arith.select between an all-true and an all-false i1 vector into
/// a broadcast of the (scalar) condition. Only vector<1xi1> is supported.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions are supported.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // TODO: Support n-D and scalable vectors with multiple elements.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace the select with its condition broadcast to a 1-element vector.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
    return success();
  }
};
/// Returns the number of trailing unit dims that can be folded away from a
/// transfer op on `srcType`/`vectorType`, or failure if the memref strides
/// cannot be determined.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // The vector may be a slice of the memref, so offset the checked index by
  // the rank difference.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // A dim can be folded only if both the memref dim and the vector dim are
    // unit and the memref dim is contiguous (stride 1).
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}

/// Drop inner most contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getBase());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView =
        memref::SubViewOp::create(rewriter, loc, resultMemrefType,
                                  readOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = vector::TransferReadOp::create(
        rewriter, loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};

/// Drop inner most contiguous unit dimensions from transfer_write operand.
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    Location loc = writeOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getBase());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

    Value rankedReducedView =
        memref::SubViewOp::create(rewriter, loc, resultMemrefType,
                                  writeOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};
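/// Canonicalize a vector.contraction with row-major matmul semantics into the
/// "MMT" canonical form (A row-major, B transposed/col-major, C row-major),
/// i.e. indexing maps {(m, k), (n, k), (m, n)}, inserting vector.transpose
/// ops (looking through arith.extsi/extui) as needed.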
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector transpose, looking through an optional sign/zero
    // extension so the transpose operates on the narrower type.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return arith::ExtSIOp::create(rewriter, loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return arith::ExtUIOp::create(rewriter, loc, newType, trans);
      }
      return vector::TransposeOp::create(rewriter, loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
/// Pattern to fold arithmetic extensions on floating point data types into
/// vector contraction operations.
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
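/// Fold a chain of vector.reduction <add> ops feeding each other through the
/// accumulator into a vector add followed by a single reduction. Sketch
/// (illustrative):
///
///   %a = vector.reduction <add>, %x, %acc
///   %b = vector.reduction <add>, %y, %a
///  =>
///   %s = arith.addf %x, %y
///   %b = vector.reduction <add>, %s, %acc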
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // Accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();
    if (!acc.getType().isIntOrFloat())
      return failure();

    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
                                   op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};
/// Returns `inVecTy` with all non-scalable unit dims dropped. Scalable unit
/// dims are kept, since folding them would require "shifting" the scalable
/// flag onto another fixed-width dim (e.g. vector<[1]x4xf32> ->
/// vector<[4]xf32>). If every dim is dropped, returns a 1-element vector.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped; return vector<1xeType>.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
/// For vectors with at least one unit dim, replaces:
///   elementwise(a, b)
/// with:
///   sc_a = shape_cast(a); sc_b = shape_cast(b)
///   res = elementwise(sc_a, sc_b)
///   return shape_cast(res)
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // For `Elementwise` ops all operands are guaranteed to have identical
    // shapes, so it suffices to check one of them.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");

      auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Create an updated elementwise op with the new shape, then restore the
    // expected shape with a shape_cast.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                             elementwiseOp->getResult(0));
    return success();
  }
};
/// A pattern to drop unit dims from vector.transpose.
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop unit dims from the transpose permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // If all dims were unit, the source collapses to vector<1xT> and the
    // permutation must be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    // Drop the unit dims via shape_cast, transpose, then restore the shape.
    auto dropDimsShapeCast = vector::ShapeCastOp::create(
        rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
    auto transposeWithoutUnitDims =
        vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), transposeWithoutUnitDims);
    return success();
  }
};
/// A pattern to drop unit dims from the iter_args of an scf.for.
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Find the first iter_arg with droppable unit dims. Further applications
    // of this pattern will handle the remaining arguments.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;

      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;

      // Create a new ForOp with that iter operand replaced.
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return vector::ShapeCastOp::create(b, loc, type, source);
      };

      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp,
                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
                                                    replacement, castFn));
      return success();
    }
    return failure();
  }
};
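/// Eliminate a redundant floating-point zero added to a reduction operand,
/// e.g. reduction(add(add(x, zero), y)) -> reduction(add(x, y)).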
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other reduction kinds and their identity values.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by `arith.addi` folders; only check
    // for floats here.
    if (!isa<FloatType>(elemType))
      return failure();

    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!addLhs)
      return failure();

    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
                                        addLhs.getLhs(), vAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
                                                     op.getAcc());
    return success();
  }
};
/// Break down a 1-d vector.reduction into a series of extracts and scalar
/// reduction ops, as long as the vector has no more than
/// `maxNumElementsToExtract` elements.
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
                                                static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;

  // Returns whether a vector.broadcast matches the requirements for an
  // outerproduct pattern, i.e. a 1D-to-2D broadcast with no broadcast unit
  // dimension.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail if the broadcast introduces unit dims, to avoid generating
    // shape_casts/broadcasts which do not belong in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Avoid broadcasts like f32 or vector<f32> -> ResType.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
    if (!resType)
      return failure();
    if (resType.getRank() != 2)
      return failure();
    // If operandA can be written as tr(broadcast(A)) and operandB as
    // broadcast(B), where the broadcasts are 1D-to-2D, create and return
    // vector.outerproduct(A, B). Returns failure() otherwise.
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return vector::OuterProductOp::create(
          rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity: the transposed op is the outerproduct LHS.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
                                                        benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"