#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "vector-to-vector"
// Helper to extract the values of an ArrayAttr of IntegerAttrs as a
// SmallVector of the requested integer type.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
/// Convert a `vector.multi_reduction` over an add-combined multiplication
/// into a `vector.contract`.
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      // Parallel dims stay in the result map; reduction dims are dropped.
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symbolCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            }))));
    return success();
  }
};
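For orientation, an IR-level sketch of this rewrite; the shapes are illustrative, not taken from the listing:

// Illustrative:
//   %mul = arith.mulf %a, %b : vector<8x16xf32>
//   %red = vector.multi_reduction <add>, %mul, %acc [1]
//       : vector<8x16xf32> to vector<8xf32>
// becomes
//   %red = vector.contract {
//       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                        affine_map<(d0, d1) -> (d0, d1)>,
//                        affine_map<(d0, d1) -> (d0)>],
//       iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>}
//       %a, %b, %acc : vector<8x16xf32>, vector<8x16xf32> into vector<8xf32>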
/// Merge a TransposeOp on the lhs/rhs into a ContractionOp user by composing
/// the permutation into the corresponding indexing map.
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
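The effect is to absorb an operand transpose into the contraction's indexing maps instead of materializing it; an illustrative sketch:

// Illustrative: with %t = vector.transpose %a, [1, 0], a contraction
//   vector.contract {...} %t, %b, %acc
// becomes
//   vector.contract {...} %a, %b, %acc
// where the lhs indexing map has been composed with the inverse of the
// [1, 0] permutation, so no transpose is executed at runtime.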
/// Merge a TransposeOp of a ContractionOp result (paired with a transposed
/// accumulator) into the contraction's result indexing map.
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();
    MLIRContext *context = contractOp.getContext();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    AffineMap contractMap = maps.back();
    auto accTMap =
        AffineMap::getPermutationMap(accTOp.getPermutation(), context);
    auto resTMap =
        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
    // ... (legality check that the acc and result transposes cancel out is
    //      elided in this listing)
    auto combinedResMap = resTMap.compose(contractMap);
    maps.back() = combinedResMap;
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
/// Pull a broadcast feeding a (possibly masked) contraction into the
/// contraction's indexing maps, so the contraction operates on the
/// un-broadcast operand; the mask is shape-cast accordingly.
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
                                             MaskingOpInterface maskingOp,
                                             PatternRewriter &rewriter) {
  SmallVector<AffineMap> maps =
      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
  Value lhs = contractOp.getLhs();
  Value rhs = contractOp.getRhs();
  size_t index = 0;
  bool changed = false;
  for (Value *operand : {&lhs, &rhs}) {
    AffineMap &map = maps[index++];
    auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      continue;
    // contractionOp can only take vector as operands.
    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
    if (!srcType ||
        srcType.getRank() == broadcast.getResultVectorType().getRank())
      continue;
    int64_t rankDiff =
        broadcast.getResultVectorType().getRank() - srcType.getRank();
    bool innerDimBroadcast = false;
    SmallVector<AffineExpr> originalDims;
    for (const auto &dim : llvm::enumerate(srcType.getShape())) {
      if (dim.value() !=
          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
        innerDimBroadcast = true;
        break;
      }
      originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
    }
    // Contract doesn't support inner-dimension broadcasts.
    if (innerDimBroadcast)
      continue;

    // It would be incorrect to fold a broadcast onto a reduction dimension
    // of non-unit size.
    bool nonUnitDimReductionBroadcast = false;
    for (int64_t i = 0; i < rankDiff; ++i) {
      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
          isReductionIterator(contractOp.getIteratorTypes()
                                  .getValue()[map.getDimPosition(i)])) {
        nonUnitDimReductionBroadcast = true;
        break;
      }
    }
    if (nonUnitDimReductionBroadcast)
      continue;

    AffineMap broadcastMap =
        AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                       originalDims, contractOp.getContext());
    map = broadcastMap.compose(map);
    *operand = broadcast.getSource();
    changed = true;
  }
  if (!changed)
    return failure();

  // Determine which dims are now unused and compress them away from the maps.
  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
  for (AffineMap &map : maps)
    map = compressDims(map, unusedDimsBitVector);
  SmallVector<Attribute> iterators;
  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
    if (!unusedDimsBitVector.test(i))
      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
  }

  // Check whether any of the unused dims is non-unit in the mask, in which
  // case it cannot simply be dropped.
  VectorType oldMaskType;
  bool isAnyUnusedDimNonUnit = false;
  if (maskingOp) {
    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
        isAnyUnusedDimNonUnit = true;
        break;
      }
    }
  }

  // Bail if compressing unused dims would remove all reduction dim pairs;
  // a contract needs at least one reduction applying to both sides.
  bool hasReductionIteratorApplyingOnBothSides = false;
  for (unsigned i = 0; i < iterators.size(); ++i) {
    if (!isReductionIterator(iterators[i]))
      continue;
    if (getUnusedDimsBitVector({maps[0], maps[1]}).test(i))
      continue;
    hasReductionIteratorApplyingOnBothSides = true;
    break;
  }
  if (!hasReductionIteratorApplyingOnBothSides)
    return failure();

  Operation *newOp = vector::ContractionOp::create(
      rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
      rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

  // Handle the mask.
  if (maskingOp) {
    if (isAnyUnusedDimNonUnit)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Cannot drop non-unit mask dim.");
    assert(unusedDimsBitVector.size() ==
               static_cast<size_t>(oldMaskType.getRank()) &&
           "The mask rank is incorrect!");

    // The unused dims are unit (checked above) and leading, so a shape_cast
    // that drops them is enough to update the mask.
    Value mask = maskingOp.getMask();
    if (unusedDimsBitVector.count() != 0) {
      auto newShape =
          oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
      auto newShapeScalableDims =
          oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
      VectorType maskOpType =
          VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
      mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
                                         maskOpType, maskingOp.getMask())
                 .getResult();
    }
    newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
  }
  return newOp->getResult(0);
}
/// Maskable wrapper: delegates both the plain and the masked case to
/// combineContractAndBroadcast.
struct CombineContractBroadcastMask
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
  }
};
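Combined with the helper above, the net rewrite absorbs a leading unit-dim broadcast into the contraction; an illustrative sketch:

// Illustrative:
//   %b = vector.broadcast %lhs : vector<16xf32> to vector<1x16xf32>
//   %r = vector.contract {...} %b, %rhs, %acc
// becomes a contraction directly on %lhs: the broadcast dimension is composed
// into the lhs indexing map, the now-unused unit dim is dropped from the maps
// and iterator list, and a masked contraction gets its mask shape_cast to the
// reduced rank.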
/// Reorder cast(broadcast) -> broadcast(cast) so the cast runs on the smaller
/// pre-broadcast value.
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    if (!isa<VectorType>(op->getResult(0).getType()))
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
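Illustratively, the cast is hoisted above the broadcast so it operates on fewer elements:

// Illustrative:
//   %b = vector.broadcast %x : vector<4xi32> to vector<8x4xi32>
//   %c = arith.sitofp %b : vector<8x4xi32> to vector<8x4xf32>
// becomes
//   %c0 = arith.sitofp %x : vector<4xi32> to vector<4xf32>
//   %c  = vector.broadcast %c0 : vector<4xf32> to vector<8x4xf32>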
/// Reorder elementwise(transpose) -> transpose(elementwise), merging the
/// per-operand transposes into a single transpose of the result.
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // ... (guards: one vector result, vector operands; elided in this listing)
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::all_equal(transposeMaps))
      return failure();

    SmallVector<Value> srcValues;
    // If there are non-transposed operands (e.g. constants), transpose them
    // back to the pre-transpose shape with the inverse permutation.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        auto vectorType = srcType.clone(
            cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(vector::TransposeOp::create(
            rewriter, operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
/// Bubble a vector.bitcast down across a vector.extract so only the packed
/// element containing the requested scalar is cast.
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars from rank-1 vectors for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail if there is only one element in the cast op source.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to more elements for now, e.g.
    // vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // Only static positions are supported.
    auto mixedPos = extractOp.getMixedPosition();
    if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Get the single scalar (as a 1-element vector) in the source value that
    // packs the desired scalar.
    Location loc = extractOp.getLoc();
    Value packedValue = vector::ExtractOp::create(
        rewriter, loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
                                           rewriter.getZeroAttr(packedVecType));
    packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
                                           /*position=*/0);

    // Cast it to a vector with the desired scalar's type, then extract.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
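A worked example of the index arithmetic, with illustrative shapes:

// Illustrative:
//   %f16 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %e   = vector.extract %f16[5] : f16 from vector<8xf16>
// expandRatio = 8 / 4 = 2, so index 5 maps to packed element 5 / 2 = 2 and
// sub-element 5 % 2 = 1:
//   %p  = vector.extract %src[2] : f32 from vector<4xf32>
//   %pv = vector.bitcast <%p packed as vector<1xf32>> to vector<2xf16>
//   %e  = vector.extract %pv[1] : f16 from vector<2xf16>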
/// Bubble a vector.bitcast down across an extract_strided_slice, taking the
/// slice on the pre-cast vector instead.
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If there are fewer offsets than the rank, the last bitcast dimension is
    // implicitly selected in full; otherwise scale down its offset.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
        newOffsets, newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
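In effect the last-dimension offsets and sizes are divided by expandRatio and the bitcast is applied to the smaller slice; illustratively:

// Illustrative: extracting a 4-element f16 slice at offset 8 from
//   %cast = vector.bitcast %src : vector<8xf32> to vector<16xf16>
// (expandRatio = 2) becomes a 2-element f32 slice at offset 4 of %src,
// followed by a bitcast of that slice to vector<4xf16>.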
/// Bubble a vector.bitcast up across a vector.insert by casting both the
/// inserted value and the destination.
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
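A small worked example of the ratio logic, with illustrative types:

// Illustrative:
//   %i = vector.insert %v, %dst [0] : vector<8xf16> into vector<4x8xf16>
//   %c = vector.bitcast %i : vector<4x8xf16> to vector<4x4xf32>
// becomes (shrinking, ratio = 8 / 4 = 2):
//   %vc = vector.bitcast %v : vector<8xf16> to vector<4xf32>
//   %dc = vector.bitcast %dst : vector<4x8xf16> to vector<4x4xf32>
//   %c  = vector.insert %vc, %dc [0] : vector<4xf32> into vector<4x4xf32>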
/// Bubble a vector.bitcast up across an insert_strided_slice, casting both
/// the inserted slice and the destination.
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert's source and destination to have the same rank;
    // other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Requires that the shape of the insert's source is castable to dstType.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
/// Break a rank-1 vector.bitcast into per-slice bitcasts, controlled by
/// `controlFn`.
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support the rank-1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                           rewriter.getZeroAttr(elemType));
    Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = ExtractStridedSliceOp::create(
          rewriter, loc, bitcastOp.getSource(),
          ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
      Value bitcast =
          BitCastOp::create(rewriter, loc, newCastDstType, extracted);
      res = InsertStridedSliceOp::create(
          rewriter, loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};
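For intuition, an illustrative breakdown with concrete shapes:

// Illustrative breakdown of
//   %r = vector.bitcast %src : vector<8xf16> to vector<4xf32>
// with shrinkRatio = 8 / 4 = 2 (two iterations):
//   %s0 = vector.extract_strided_slice %src {offsets = [0], sizes = [4]}
//   %b0 = vector.bitcast %s0 : vector<4xf16> to vector<2xf32>
//   %r0 = vector.insert_strided_slice %b0, %init {offsets = [0]}
//   %s1 = vector.extract_strided_slice %src {offsets = [4], sizes = [4]}
//   %b1 = vector.bitcast %s1 : vector<4xf16> to vector<2xf32>
//   %r  = vector.insert_strided_slice %b1, %r0 {offsets = [2]}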
/// Returns true iff both types are vectors of the same shape and scalability
/// (or both are non-vector types).
static bool haveSameShapeAndScaling(Type t, Type u) {
  auto tVec = dyn_cast<VectorType>(t);
  auto uVec = dyn_cast<VectorType>(u);
  if (!tVec)
    return !uVec;
  if (!uVec)
    return false;
  return tVec.getShape() == uVec.getShape() &&
         tVec.getScalableDims() == uVec.getScalableDims();
}
/// If `type` is shaped, clone it with `newElementType`; otherwise return
/// `newElementType`.
static Type cloneOrReplace(Type type, Type newElementType) {
  if (auto shapedType = dyn_cast<ShapedType>(type)) {
    return shapedType.clone(newElementType);
  }
  return newElementType;
}
/// If `value` is the result of a broadcast, return the broadcast source;
/// otherwise return a null Value.
static Value getBroadcastLikeSource(Value value) {
  Operation *op = value.getDefiningOp();
  if (!op)
    return {};
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcast.getSource();
  return {};
}
/// Reorder elementwise(broadcast/splat) -> broadcast(elementwise), so the
/// elementwise op runs on the narrow pre-broadcast operands.
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1)
      return failure();
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    if (op->getNumOperands() == 0)
      return failure();
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    Type resultElemType = resultType.getElementType();

    // Get the type of the first non-constant operand's broadcast source.
    Value broadcastSource;
    for (Value operand : op->getOperands()) {
      Operation *definingOp = operand.getDefiningOp();
      if (!definingOp)
        return failure();
      if (definingOp->hasTrait<OpTrait::ConstantLike>())
        continue;
      broadcastSource = getBroadcastLikeSource(operand);
      break;
    }
    if (!broadcastSource)
      return failure();
    Type unbroadcastResultType =
        cloneOrReplace(broadcastSource.getType(), resultElemType);

    // Make sure that all operands are broadcast from identically-shaped types
    // (scalar or vector broadcast sources) or are splat constants.
    if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
          if (auto source = getBroadcastLikeSource(val))
            return haveSameShapeAndScaling(source.getType(),
                                           broadcastSource.getType());
          SplatElementsAttr splatConst;
          return matchPattern(val, m_Constant(&splatConst));
        })) {
      return rewriter.notifyMatchFailure(
          op,
          "not all operands are constants or broadcasts from the same type");
    }

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      SplatElementsAttr splatConst;
      if (matchPattern(operand, m_Constant(&splatConst))) {
        // Rebuild the splat constant at the unbroadcast type.
        Attribute newConst;
        Type elementType = getElementTypeOrSelf(operand.getType());
        Type newType = cloneOrReplace(unbroadcastResultType, elementType);
        if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
          newConst = splatConst.resizeSplat(newTypeShaped);
        } else {
          newConst = splatConst.getSplatValue<Attribute>();
        }
        Operation *newConstOp =
            operand.getDefiningOp()->getDialect()->materializeConstant(
                rewriter, newConst, newType, operand.getLoc());
        srcValues.push_back(newConstOp->getResult(0));
      } else {
        srcValues.push_back(operand.getDefiningOp()->getOperand(0));
      }
    }

    // Create the "narrow" elementwise op and broadcast its result.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        unbroadcastResultType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, resultType, elementwiseOp->getResults());
    return success();
  }
};
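Illustratively, the rewrite hoists the arithmetic above the broadcast:

// Illustrative:
//   %a = vector.broadcast %x : f32 to vector<16xf32>
//   %b = vector.broadcast %y : f32 to vector<16xf32>
//   %s = arith.addf %a, %b : vector<16xf32>
// becomes
//   %s0 = arith.addf %x, %y : f32
//   %s  = vector.broadcast %s0 : f32 to vector<16xf32>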
/// Sink a vector.extract below an elementwise op, so the elementwise
/// computation happens on the extracted value.
class ExtractOpFromElementwise final
    : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *eltwise = op.getSource().getDefiningOp();
    // `vector.fma` has no scalar form, so it is excluded.
    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
        isa<vector::FMAOp>(eltwise))
      return rewriter.notifyMatchFailure(op, "not an elementwise op");
    // ... (single-result / single-use guards elided in this listing)
    if (!op.getDynamicPosition().empty())
      return rewriter.notifyMatchFailure(
          op, "dynamic position not yet implemented");

    Type dstType = op.getType();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(eltwise);

    IRMapping mapping;
    Location loc = eltwise->getLoc();
    SmallVector<OpFoldResult> pos = op.getMixedPosition();
    for (Value arg : eltwise->getOperands()) {
      Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
      mapping.map(arg, newArg);
    }

    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
    newEltwise->getResult(0).setType(dstType);

    rewriter.replaceOp(op, newEltwise->getResults());
    rewriter.eraseOp(eltwise);
    return success();
  }
};
/// Check if the element type is suitable for vector.load/store sinking:
/// it must be index or a byte-aligned integer or floating-point type.
static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;
  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}
/// Sink a vector.extract through a single-use vector.load, loading only the
/// relevant element(s) at an adjusted index.
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return rewriter.notifyMatchFailure(op, "expected a load op");

    // Checking for single use so we can remove the load.
    if (!loadOp->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();
    // Non-byte-aligned types may require special handling; skip them for now.
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return rewriter.notifyMatchFailure(op, "unsupported element type");

    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    if (rankOffset < 0)
      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();

    SmallVector<Value> indices(loadOp.getIndices());
    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();

    // There may be memory stores between the load and the extract, so insert
    // the new load at the original load's position.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(loadOp);
    Location loc = loadOp.getLoc();
    ArithIndexingBuilder idxBuilderf(rewriter, loc);
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      OpFoldResult pos = extractPos[i - rankOffset];
      if (isZeroInteger(pos))
        continue;
      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    Value base = loadOp.getBase();
    if (extractVecType) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
                                                  indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
    }
    // We checked for single use, so the load can be safely removed.
    rewriter.eraseOp(loadOp);
    return success();
  }
};
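The net effect is replacing a wide load plus extract with a narrow load at an adjusted index; an illustrative sketch (names hypothetical):

// Illustrative:
//   %v = vector.load %mem[%i] : memref<64xf32>, vector<4xf32>
//   %e = vector.extract %v[2] : f32 from vector<4xf32>
// becomes (single-use load, position folded into the index):
//   %i2 = arith.addi %i, %c2 : index
//   %e  = memref.load %mem[%i2] : memref<64xf32>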
/// Replace a store of a 1-element broadcast with a store of its (scalar or
/// smaller vector) source.
class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Value toStore = op.getValueToStore();
    Value source = getBroadcastLikeSource(toStore);
    if (!source)
      return rewriter.notifyMatchFailure(
          op, "value to store is not from a broadcast");

    // Checking for single use so we can remove the broadcast.
    Operation *broadcast = toStore.getDefiningOp();
    if (!broadcast->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    Value base = op.getBase();
    ValueRange indices = op.getIndices();
    if (isa<VectorType>(source.getType())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
    }
    rewriter.eraseOp(broadcast);
    return success();
  }
};
/// Build the vector comparison `indices + offset < bound` used to
/// materialize a mask, optionally in 32-bit for better SIMD parallelism.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If all indices are known to fit in 32 bits, compare in 32-bit.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = rewriter.getI64VectorAttr(ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
    indices = arith::AddIOp::create(rewriter, loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
                               indices, bounds);
}
/// Materialize an out-of-bounds transfer's implicit mask as an explicit
/// `vector.create_mask`, combined with any existing mask.
template <typename ConcreteOp>
class MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements in [0 .. dim - offset) set
    // and the rest unset.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim = vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(),
                                          lastIndex);
    Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
    Value mask = vector::CreateMaskOp::create(
        rewriter, loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds mask with the mask specified on the op.
      mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
    });

    return success();
  }

private:
  const bool force32BitVectorIndices;
};
/// Conversion pattern lowering `vector.create_mask` (0-D and 1-D only) to a
/// vector comparison.
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
/// Returns true if `constantOp` is a dense i1 splat with all values equal to
/// `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constants.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
/// Fold a select of all-true/all-false single-element i1 vectors into a
/// broadcast of its scalar condition.
// ... (enclosing OpRewritePattern<arith::SelectOp> class elided in this
//      listing)
  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // TODO: Support n-D and scalable vectors.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();

    // TODO: Support vectors with multiple elements.
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace the select with its condition broadcast to a single-element
    // vector.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
    return success();
  }
/// Returns the number of inner unit dims that can be folded away from a
/// transfer op on `srcType`/`vectorType`: the inner-most dims must be
/// unit-size, contiguous (stride 1), and unit in the source memref as well.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // Check inner-most dims first.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}
/// Drop inner-most contiguous unit dimensions from a transfer_read's memref
/// operand via a rank-reducing subview.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getBase());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView =
        memref::SubViewOp::create(rewriter, loc, resultMemrefType,
                                  readOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = vector::TransferReadOp::create(
        rewriter, loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
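An illustrative sketch of the rank reduction (shapes hypothetical):

// Illustrative:
//   %v = vector.transfer_read %src[%c0, %c0, %c0], %pad
//       {in_bounds = [true, true, true]}
//       : memref<8x1x1xf32>, vector<4x1x1xf32>
// becomes
//   %sv = memref.subview %src[0, 0, 0] [8, 1, 1] [1, 1, 1]
//       : memref<8x1x1xf32> to memref<8xf32>
//   %r  = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
//       : memref<8xf32>, vector<4xf32>
//   %v  = vector.shape_cast %r : vector<4xf32> to vector<4x1x1xf32>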
/// Drop inner-most contiguous unit dimensions from a transfer_write's memref
/// operand, mirroring the transfer_read pattern above.
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    Location loc = writeOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getBase());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

    Value rankedReducedView =
        memref::SubViewOp::create(rewriter, loc, resultMemrefType,
                                  writeOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};
/// Canonicalize a vector.contract with row-major matmul semantics to the
/// "MMT" form (A row-major, B col-major, C row-major), inserting transposes
/// as needed.
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector transpose, making sure to emit any zero/sign-extension
    // after the transpose.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return arith::ExtSIOp::create(rewriter, loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return arith::ExtUIOp::create(rewriter, loc, newType, trans);
      }
      return vector::TransposeOp::create(rewriter, loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
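A worked example of the map matching (maps written informally):

// Illustrative: a plain row-major matmul contraction with
//   indexing_maps = [(m, k), (k, n), (m, n)]
// hits the `infer({{m, k}, {k, n}, {m, n}})` case: only the RHS deviates
// from the canonical MMT form [(m, k), (n, k), (m, n)], so a single
//   %rhsT = vector.transpose %rhs, [1, 0]
// is inserted and the contraction is rebuilt with the canonical maps, ready
// for an MMT-style lowering.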
/// Fold arithmetic extensions (ExtOp) on both contraction operands into the
/// mixed-precision contraction itself.
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
/// Fold chained reductions into a vector add followed by a single reduction.
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();
    // The accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();
    if (!acc.getType().isIntOrFloat())
      return failure();
    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
                                   op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};
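The fold replaces two chained reductions with one vector add and a single reduction; illustratively:

// Illustrative:
//   %a = vector.reduction <add>, %v0, %acc : vector<8xf32> into f32
//   %b = vector.reduction <add>, %v1, %a  : vector<8xf32> into f32
// becomes
//   %s = arith.addf %v0, %v1 : vector<8xf32>
//   %b = vector.reduction <add>, %s, %acc : vector<8xf32> into f32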
/// Returns `inVecTy` with all non-scalable unit dims dropped (or
/// vector<1xeType> if every dim was dropped).
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;
    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped: return vector<1xeType>.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }
  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
/// For vectors with at least one unit dim, replace elementwise(a, b) with
/// shape_cast(elementwise(shape_cast(a), shape_cast(b))).
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();
    // ... (guards on operand types elided in this listing)
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    Location loc = op->getLoc();
    for (Value operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      VectorType newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
      auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Create the new elementwise op and shape_cast back to the original type.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, resultVectorType, elementwiseOp->getResult(0));
    return success();
  }
};
/// Drop unit dims from vector.transpose: shape_cast them away, transpose with
/// a compressed permutation, then shape_cast back.
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);
    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop unit dims from the permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // Fixup for `newPerm`: `sourceTypeWithoutUnitDims` could be vector<1xT>,
    // in which case `newPerm` must be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    auto dropDimsShapeCast = vector::ShapeCastOp::create(
        rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
    auto transposeWithoutUnitDims =
        vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), transposeWithoutUnitDims);
    return success();
  }
};
/// A pattern to drop unit dims from the iter_args of an scf.for.
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Find the first iter_arg with droppable unit dims; later arguments are
    // handled by further applications of this pattern.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;
      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;

      // Create a new ForOp with that iter operand replaced.
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return vector::ShapeCastOp::create(b, loc, type, source);
      };
      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp,
                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
                                                    replacement, castFn));
      return success();
    }
    return failure();
  }
};
/// Eliminate a redundant +0.0 added to a float reduction operand.
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();
    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by `arith.addi` folders; only check
    // for floats here.
    if (!isa<FloatType>(elemType))
      return failure();
    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!addLhs)
      return failure();
    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
                                        addLhs.getLhs(), vAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
                                                     op.getAcc());
    return success();
  }
};
/// Break down a small vector.reduction into per-element extracts combined
/// with scalar arith ops.
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
                                                static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};
/// Fold an arith.mul of a transposed broadcast and a broadcast into a
/// vector.outerproduct.
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;

  // Returns whether a vector.broadcast matches the requirements of an
  // outerproduct operand.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail if it is not a 1-to-2 dimension broadcast, to avoid generating
    // shape_casts/broadcasts that do not belong in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Avoid broadcasts like f32 or vector<f32> -> ResType.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
    if (!resType)
      return failure();
    if (resType.getRank() != 2)
      return failure();
    // If operandA can be written as tr(broadcast(A)) and operandB as
    // broadcast(B), where the broadcasts are along the second dimension, then
    // the multiplication is an outer product A x B.
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return vector::OuterProductOp::create(
          rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity: the transposed op is the outerproduct LHS.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
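An illustrative sketch of the matched shape (types hypothetical):

// Illustrative:
//   %bA = vector.broadcast %a : vector<4xf32> to vector<8x4xf32>
//   %tA = vector.transpose %bA, [1, 0] : vector<8x4xf32> to vector<4x8xf32>
//   %bB = vector.broadcast %b : vector<8xf32> to vector<4x8xf32>
//   %m  = arith.mulf %tA, %bB : vector<4x8xf32>
// becomes
//   %m = vector.outerproduct %a, %b : vector<4xf32>, vector<8xf32>
// since %tA[i][j] = a[i] and %bB[i][j] = b[j], i.e. %m[i][j] = a[i] * b[j].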
void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // ... (registers the DropUnitDim* patterns above; elided in this listing)
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
                                                        benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
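To close, a minimal sketch of how these entry points are typically consumed. The pass wrapper below is hypothetical; the populate function and the greedy driver are the standard upstream MLIR APIs (in older MLIR releases the driver is named `applyPatternsAndFoldGreedily`):

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: greedily apply a few of the rewrites defined in this
// file to everything nested under `root`.
static LogicalResult applyVectorToVectorRewrites(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorReductionToContractPatterns(patterns);
  vector::populateSinkVectorOpsPatterns(patterns);
  vector::populateChainedVectorReductionFoldingPatterns(patterns);
  // Runs to a fixpoint; fails if the driver does not converge.
  return applyPatternsGreedily(root, std::move(patterns));
}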