30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32#include "llvm/Support/FormatVariadic.h"
39#define DEBUG_TYPE "vector-to-vector"
// Fragment: template helper that converts an ArrayAttr of IntegerAttr into a
// SmallVector of IntType values. NOTE(review): the signature line (orig. 45)
// is missing from this chunk, so the function name and parameter list cannot
// be confirmed from here.
44template <
typename IntType>
 46 return llvm::to_vector<4>(llvm::map_range(
 47 arrayAttr.getAsRange<IntegerAttr>(),
 48 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
// Pattern fragment: rewrites vector.multi_reduction<add>(mul(a, b)) into a
// vector.contraction. Only the ADD combining kind fed by arith.muli/mulf is
// matched (guards below). NOTE(review): many interior lines (orig. 81-83, 87,
// 90, etc.) are missing from this chunk — return statements and the rewrite
// tail are not visible.
80struct MultiReduceToContract
 84 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
 85 PatternRewriter &rewriter)
const override {
 86 if (reduceOp.getKind() != vector::CombiningKind::ADD)
 88 Operation *mulOp = reduceOp.getSource().getDefiningOp();
 89 if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
 91 SmallVector<bool> reductionMask = reduceOp.getReductionMask();
 93 SmallVector<AffineExpr> exprs;
 94 SmallVector<vector::IteratorType> iteratorTypes;
// Reduced dims become 'reduction' iterators; all others stay 'parallel'.
 95 for (
const auto &isReduceDim : llvm::enumerate(reductionMask)) {
 96 if (!isReduceDim.value()) {
 97 iteratorTypes.push_back(vector::IteratorType::parallel);
 100 iteratorTypes.push_back(vector::IteratorType::reduction);
 105 0, exprs, reduceOp.getContext());
 110 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
 111 return IteratorTypeAttr::get(rewriter.getContext(), t);
// Pattern fragment: folds vector.transpose ops feeding the lhs/rhs of a
// vector.contraction into the contraction's indexing maps, so the transpose
// op can be removed. NOTE(review): interior lines are missing (e.g. the
// `index` variable used at orig. 153 is declared on a line not shown here);
// statements below are also split mid-token by the extraction.
140struct CombineContractABTranspose final
 144 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
 145 PatternRewriter &rewriter)
const override {
 146 SmallVector<AffineMap> maps =
 147 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
 148 Value
lhs = contractOp.getLhs();
 149 Value
rhs = contractOp.getRhs();
 151 bool changed =
false;
// Visit lhs and rhs in turn; each has a matching indexing map in `maps`.
 152 for (Value *operand : {&
lhs, &
rhs}) {
 153 AffineMap &map = maps[index++];
 154 auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
 158 transposeOp.getPermutation(), contractOp.getContext());
 160 *operand = transposeOp.getVector();
 166 contractOp,
lhs,
rhs, contractOp.getAcc(),
// Pattern fragment: folds a vector.transpose of a single-use
// vector.contraction result into the contraction's accumulator/result
// indexing map (the last map), composing the result map with the transpose
// permutation. NOTE(review): guard returns and the map-composition lines
// between the visible statements are missing from this chunk.
204struct CombineContractResultTranspose final
 208 LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
 209 PatternRewriter &rewriter)
const override {
 210 auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
// Only safe when the contraction has exactly one use (this transpose).
 211 if (!contractOp || !contractOp->hasOneUse())
 214 auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
 218 MLIRContext *context = contractOp.getContext();
 219 auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
 220 AffineMap contractMap = maps.back();
 231 auto combinedResMap = resTMap.compose(contractMap);
 238 maps.back() = combinedResMap;
 241 resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
// Fragment: attempts to pull vector.broadcast ops on the operands of a
// vector.contraction into the contraction's indexing maps, dropping the
// broadcast dims. Bails out on inner-dim broadcasts and on broadcasts of
// non-unit reduction dims; also handles rewriting the mask of a masked
// contraction. NOTE(review): large spans of the original (orig. 276-284,
// 286-288, 328-342, 346-350, etc.) are missing from this chunk.
274FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
 275 MaskingOpInterface maskingOp,
 278 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
 282 bool changed =
false;
 285 auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
 289 auto srcType = dyn_cast<VectorType>(
broadcast.getSourceType());
 291 srcType.getRank() ==
broadcast.getResultVectorType().getRank())
 294 broadcast.getResultVectorType().getRank() - srcType.getRank();
// An "inner dim broadcast" stretches an existing (non-prefix) dimension;
// those cannot be expressed by dropping dims from the indexing map.
 295 bool innerDimBroadcast =
false;
 297 for (
const auto &dim : llvm::enumerate(srcType.getShape())) {
 299 broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
 300 innerDimBroadcast =
true;
 307 if (innerDimBroadcast)
// Broadcasting a non-unit reduction dim would change the reduction result.
 312 bool nonUnitDimReductionBroadcast =
false;
 313 for (
int64_t i = 0; i < rankDiff; ++i) {
 314 if (
broadcast.getResultVectorType().getDimSize(i) != 1 &&
 317 nonUnitDimReductionBroadcast =
true;
 321 if (nonUnitDimReductionBroadcast)
 326 originalDims, contractOp.getContext());
 327 map = broadcastMap.
compose(map);
// Keep only the iterator types of dims still used after dropping.
 343 for (
unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
 344 if (!unusedDimsBitVector.test(i))
 345 iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
 351 VectorType oldMaskType;
 352 bool isAnyUnusedDimNonUnit =
false;
 354 oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
 355 for (
unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
 356 if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
 357 isAnyUnusedDimNonUnit =
true;
 368 bool hasReductionIteratorApplyingOnBothSides =
false;
 369 for (
unsigned i = 0; i < iterators.size(); ++i) {
 373 hasReductionIteratorApplyingOnBothSides =
true;
 377 if (!hasReductionIteratorApplyingOnBothSides)
 385 Operation *newOp = vector::ContractionOp::create(
 386 rewriter, contractOp.getLoc(),
lhs,
rhs, contractOp.getAcc(),
 391 if (isAnyUnusedDimNonUnit)
// NOTE(review): typo in the diagnostic below — "Cannont" should read
// "Cannot". Left untouched here because it is a runtime string.
 393 "Cannont drop non-unit mask dim.");
 394 assert(unusedDimsBitVector.size() ==
 395 static_cast<size_t>(oldMaskType.getRank()) &&
 396 "The mask rank is incorrect!");
 400 Value mask = maskingOp.getMask();
// Shape-cast the mask to drop the leading dims removed from the maps.
 401 if (unusedDimsBitVector.count() != 0) {
 409 oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
 410 auto newShapeScalableDims =
 411 oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
 412 VectorType maskOpType =
 413 VectorType::get(newShape, rewriter.
getI1Type(), newShapeScalableDims);
 414 mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
 415 maskOpType, maskingOp.getMask())
// Maskable wrapper pattern: delegates to combineContractAndBroadcast so the
// broadcast-folding rewrite also applies to vector.mask-wrapped contractions.
// NOTE(review): the base-class clause (orig. 425) is missing from this chunk.
424struct CombineContractBroadcastMask
 426 using MaskableOpRewritePattern::MaskableOpRewritePattern;
 429 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
 430 MaskingOpInterface maskingOp,
 431 PatternRewriter &rewriter)
const override {
 432 return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
// Pattern fragment: reorders cast(broadcast(x)) into broadcast(cast(x)) for
// single-operand CastOpInterface ops producing a vector, cloning the cast on
// the (smaller) broadcast source. NOTE(review): guard returns and the final
// replaceOpWithNewOp call are split/missing in this chunk.
449struct ReorderCastOpsOnBroadcast
 451 using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
 453 LogicalResult matchAndRewrite(CastOpInterface op,
 454 PatternRewriter &rewriter)
const override {
 455 if (op->getNumOperands() != 1)
 457 if (!isa<VectorType>(op->getResult(0).getType()))
 459 auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
// If the broadcast source is itself a vector, keep its shape for the cast.
 464 if (
auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
 465 castResTy = vecTy.clone(castResTy);
 467 rewriter.
create(op->getLoc(), op->getName().getIdentifier(),
 468 bcastOp.getSource(), castResTy, op->getAttrs());
 470 op, op->getResult(0).
getType(), castOp->getResult(0));
// Pattern fragment: pushes vector.transpose ops below elementwise ops —
// elementwise(transpose(a), transpose(b)) becomes
// transpose(elementwise(a, b)) — provided all transposed operands share the
// same permutation. Non-transpose operands (e.g. splat constants) are
// re-transposed with the inverse permutation so operand shapes line up.
// NOTE(review): several guard/return lines are missing from this chunk.
489struct ReorderElementwiseOpsOnTranspose final
 492 LogicalResult matchAndRewrite(Operation *op,
 493 PatternRewriter &rewriter)
const override {
 499 SmallVector<ArrayRef<int64_t>> transposeMaps;
 505 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
 507 transposeMaps.push_back(transposeOp.getPermutation());
 508 srcType = transposeOp.getSourceVectorType();
 513 if (transposeMaps.empty())
// All transposed operands must use the identical permutation.
 518 if (!llvm::all_equal(transposeMaps))
 521 SmallVector<Value> srcValues;
// Build the inverse permutation for operands that are not transposes.
 526 auto order = transposeMaps.front();
 527 SmallVector<int64_t> invOrder(order.size());
 528 for (
int i = 0, e = order.size(); i < e; ++i)
 529 invOrder[order[i]] = i;
 532 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
 534 srcValues.push_back(transposeOp.getVector());
 538 srcType.clone(cast<VectorType>(operand.getType()).getElementType());
 539 srcValues.push_back(vector::TransposeOp::create(
 540 rewriter, operand.getLoc(), vectorType, operand, invOrder));
 544 auto vectorType = srcType.clone(
 546 Operation *elementwiseOp =
 551 transposeMaps.front());
// Fragment: helper body converting an ArrayAttr of IntegerAttr into a vector
// of raw integer values. NOTE(review): the signature line (orig. 557) is not
// visible in this chunk.
 558 return llvm::map_to_vector<4>(arrayAttr.getAsRange<IntegerAttr>(),
 559 [](IntegerAttr attr) { return attr.getInt(); });
// Pattern fragment: bubbles a vector.bitcast below a scalar vector.extract —
// extract(bitcast(x), i) becomes bitcast(extract-slice(x, i/ratio))[i%ratio]
// when the bitcast expands the element count (e.g. vector<4xf32> ->
// vector<8xf16>). Restricted to rank-1 sources and static positions.
// NOTE(review): guard returns between visible lines are missing.
571struct BubbleDownVectorBitCastForExtract
 575 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
 576 PatternRewriter &rewriter)
const override {
// Only rank-1 vectors; n-D extracts are handled elsewhere.
 578 if (extractOp.getSourceVectorType().getRank() != 1)
 581 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
 585 VectorType castSrcType = castOp.getSourceVectorType();
 586 VectorType castDstType = castOp.getResultVectorType();
 587 assert(castSrcType.getRank() == castDstType.getRank());
 592 if (castSrcType.getNumElements() == 1)
// Only handle the expanding direction (dst has more, narrower elements).
 597 if (castSrcType.getNumElements() > castDstType.getNumElements())
 600 unsigned expandRatio =
 601 castDstType.getNumElements() / castSrcType.getNumElements();
// Dynamic positions cannot be folded; require a static index attribute.
 604 auto mixedPos = extractOp.getMixedPosition();
 605 if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
 607 uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
 611 Location loc = extractOp.getLoc();
 612 Value packedValue = vector::ExtractOp::create(
 613 rewriter, loc, castOp.getSource(), index / expandRatio);
 614 Type packedVecType = VectorType::get({1}, packedValue.
getType());
 615 Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
 617 packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
 622 VectorType packedType =
 623 VectorType::get({expandRatio}, castDstType.getElementType());
 625 vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
 629 index % expandRatio);
// Pattern fragment: bubbles a vector.bitcast below a
// vector.extract_strided_slice by scaling the slice's last-dim offsets and
// sizes by the bitcast expand ratio. Requires unit strides and offsets/sizes
// divisible by the ratio. NOTE(review): return statements after the guard
// conditions are missing from this chunk.
646struct BubbleDownBitCastForStridedSliceExtract
 650 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
 651 PatternRewriter &rewriter)
const override {
 652 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
 656 VectorType castSrcType = castOp.getSourceVectorType();
 657 VectorType castDstType = castOp.getResultVectorType();
 658 assert(castSrcType.getRank() == castDstType.getRank());
 660 int64_t castSrcLastDim = castSrcType.getShape().back();
 661 int64_t castDstLastDim = castDstType.getShape().back();
// Only the expanding direction (src elems wider than dst elems) is handled.
 663 if (castSrcLastDim > castDstLastDim)
// Non-unit strides would break the offset/size rescaling below.
 667 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
 668 [](
const APInt &val) { return !val.isOne(); }))
 671 unsigned rank = extractOp.getSourceVectorType().getRank();
 672 assert(castDstLastDim % castSrcLastDim == 0);
 673 int64_t expandRatio = castDstLastDim / castSrcLastDim;
// Rescale the last offset; it must be aligned to the expand ratio.
 679 ArrayAttr newOffsets = extractOp.getOffsets();
 680 if (newOffsets.size() == rank) {
 681 SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
 682 if (offsets.back() % expandRatio != 0)
 684 offsets.back() = offsets.back() / expandRatio;
// Rescale the last size under the same divisibility constraint.
 689 ArrayAttr newSizes = extractOp.getSizes();
 690 if (newSizes.size() == rank) {
 691 SmallVector<int64_t> sizes = getIntValueVector(newSizes);
 692 if (sizes.back() % expandRatio != 0)
 694 sizes.back() = sizes.back() / expandRatio;
 698 SmallVector<int64_t> dims =
 699 llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
 700 dims.back() = dims.back() / expandRatio;
 701 VectorType newExtractType =
 702 VectorType::get(dims, castSrcType.getElementType());
 704 auto newExtractOp = vector::ExtractStridedSliceOp::create(
 705 rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
 706 newOffsets, newSizes, extractOp.getStrides());
 709 extractOp, extractOp.getType(), newExtractOp);
// Pattern fragment: bubbles a vector.bitcast above a vector.insert —
// bitcast(insert(src, dst)) becomes insert(bitcast(src), bitcast(dst)) —
// scaling the last dimension of both source and destination by the
// element-count ratio. Scalable vectors and rank-0 casts are rejected.
// NOTE(review): some guard returns between visible lines are missing.
725struct BubbleUpBitCastForInsert :
public OpRewritePattern<vector::BitCastOp> {
 728 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
 729 PatternRewriter &rewriter)
const override {
 730 VectorType castSrcType = bitcastOp.getSourceVectorType();
 731 VectorType castDstType = bitcastOp.getResultVectorType();
 734 if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
 735 castDstType.isScalable())
 738 int64_t castSrcLastDim = castSrcType.getShape().back();
 739 int64_t castDstLastDim = castDstType.getShape().back();
// The cast either shrinks or grows the element count; compute the ratio.
 740 bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
 742 if (isNumElemsShrink) {
 743 assert(castSrcLastDim % castDstLastDim == 0);
 744 ratio = castSrcLastDim / castDstLastDim;
 746 assert(castDstLastDim % castSrcLastDim == 0);
 747 ratio = castDstLastDim / castSrcLastDim;
 750 auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
 755 auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
// Bitcast the inserted value with its last dim rescaled.
 760 SmallVector<int64_t> srcDims(insertSrcType.getShape());
 762 isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
 763 VectorType newCastSrcType =
 764 VectorType::get(srcDims, castDstType.getElementType());
 766 vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
 767 insertOp.getValueToStore());
// Bitcast the destination vector likewise.
 769 SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
 771 isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
 772 VectorType newCastDstType =
 773 VectorType::get(dstDims, castDstType.getElementType());
 776 auto newCastDstOp = vector::BitCastOp::create(
 777 rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
 781 bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
// Pattern fragment: bubbles a shrinking vector.bitcast above a
// vector.insert_strided_slice, rescaling the slice's last-dim offsets and
// both operand shapes by the shrink ratio. Requires unit strides, matching
// source/dest ranks, and offsets divisible by the ratio. NOTE(review):
// guard returns between the visible conditions are missing from this chunk.
797struct BubbleUpBitCastForStridedSliceInsert
 801 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
 802 PatternRewriter &rewriter)
const override {
 803 VectorType castSrcType = bitcastOp.getSourceVectorType();
 804 VectorType castDstType = bitcastOp.getResultVectorType();
 805 assert(castSrcType.getRank() == castDstType.getRank());
 807 if (castSrcType.getRank() == 0)
 810 int64_t castSrcLastDim = castSrcType.getShape().back();
 811 int64_t castDstLastDim = castDstType.getShape().back();
// Only the shrinking direction is handled here.
 813 if (castSrcLastDim < castDstLastDim)
 816 assert(castSrcLastDim % castDstLastDim == 0);
 817 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
 820 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
 825 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
 826 [](
const APInt &val) { return !val.isOne(); }))
 829 unsigned rank = insertOp.getSourceVectorType().getRank();
// Rank-expanding inserts are not supported by the offset rescaling.
 832 if (rank != insertOp.getDestVectorType().getRank())
 836 unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
 837 unsigned destinationWidth =
 838 castDstType.getElementType().getIntOrFloatBitWidth();
 839 unsigned numElements = destinationWidth / sourceWidth;
 840 if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
 843 ArrayAttr newOffsets = insertOp.getOffsets();
 844 assert(newOffsets.size() == rank);
 845 SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
 846 if (offsets.back() % shrinkRatio != 0)
 848 offsets.back() = offsets.back() / shrinkRatio;
// Bitcast the inserted slice with its last dim shrunk.
 851 SmallVector<int64_t> srcDims =
 852 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
 853 srcDims.back() = srcDims.back() / shrinkRatio;
 854 VectorType newCastSrcType =
 855 VectorType::get(srcDims, castDstType.getElementType());
 858 vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
 859 insertOp.getValueToStore());
// Bitcast the destination vector likewise.
 861 SmallVector<int64_t> dstDims =
 862 llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
 863 dstDims.back() = dstDims.back() / shrinkRatio;
 864 VectorType newCastDstType =
 865 VectorType::get(dstDims, castDstType.getElementType());
 867 auto newCastDstOp = vector::BitCastOp::create(
 868 rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
 871 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
 872 insertOp.getStrides());
// Pattern fragment: breaks a rank-1, shrinking vector.bitcast into
// `shrinkRatio` smaller extract_strided_slice + bitcast +
// insert_strided_slice pieces, gated by an optional user-supplied controlFn.
// Scalable vectors are rejected. NOTE(review): the struct header and some
// guard returns are missing from this chunk.
900 BreakDownVectorBitCast(MLIRContext *context,
 901 std::function<
bool(vector::BitCastOp)> controlFn,
 902 PatternBenefit benefit)
 903 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
 905 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
 906 PatternRewriter &rewriter)
const override {
// Caller-provided predicate decides which bitcasts to break down.
 908 if (controlFn && !controlFn(bitcastOp))
 911 VectorType castSrcType = bitcastOp.getSourceVectorType();
 912 VectorType castDstType = bitcastOp.getResultVectorType();
 913 assert(castSrcType.getRank() == castDstType.getRank());
 918 if (castSrcType.isScalable())
 920 "Scalable vectors are not supported");
 923 if (castSrcType.getRank() != 1)
 926 int64_t castSrcLastDim = castSrcType.getShape().back();
 927 int64_t castDstLastDim = castDstType.getShape().back();
 929 if (castSrcLastDim < castDstLastDim)
 932 assert(castSrcLastDim % castDstLastDim == 0);
 933 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
 935 if (castSrcLastDim == shrinkRatio)
 938 Location loc = bitcastOp.getLoc();
 939 Type elemType = castDstType.getElementType();
// Seed the result with a zero broadcast, then fill in each slice.
 942 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
 944 Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
 946 SmallVector<int64_t> sliceShape = {castDstLastDim};
 947 SmallVector<int64_t> strides = {1};
 948 VectorType newCastDstType =
 949 VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
 950 castDstType.getElementType());
 952 for (
int i = 0, e = shrinkRatio; i < e; ++i) {
 953 Value extracted = ExtractStridedSliceOp::create(
 954 rewriter, loc, bitcastOp.getSource(),
 955 ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
 957 BitCastOp::create(rewriter, loc, newCastDstType, extracted);
 958 res = InsertStridedSliceOp::create(
 959 rewriter, loc, bitcast, res,
 960 ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
 967 std::function<bool(BitCastOp)> controlFn;
// Fragment: returns true when both types are vectors with identical shapes
// and identical scalable-dim flags. NOTE(review): the intermediate guard
// lines (orig. 973-978, presumably handling the non-vector cases) are
// missing from this chunk.
970static bool haveSameShapeAndScaling(
Type t,
Type u) {
 971 auto tVec = dyn_cast<VectorType>(t);
 972 auto uVec = dyn_cast<VectorType>(u);
 979 return tVec.getShape() == uVec.getShape() &&
 980 tVec.getScalableDims() == uVec.getScalableDims();
// Helper: if `type` is shaped, clone it with `newElementType` as the element
// type; otherwise the new element type replaces the scalar type entirely.
985static Type cloneOrReplace(
Type type,
Type newElementType) {
 986 if (
auto shapedType = dyn_cast<ShapedType>(type)) {
 987 return shapedType.clone(newElementType);
 989 return newElementType;
// Fragment: returns the source of a broadcast-like defining op of `value`.
// NOTE(review): only the vector.broadcast arm is visible; the surrounding
// logic (orig. 995-999 and the return lines) is missing from this chunk.
994static Value getBroadcastLikeSource(
Value value) {
 1000 if (
auto broadcast = dyn_cast<vector::BroadcastOp>(op))
// Pattern fragment: pushes broadcasts below elementwise ops —
// elementwise(broadcast(a), broadcast(b)) becomes
// broadcast(elementwise(a, b)) — when all operands are broadcasts from the
// same (shape, scalability) source type or splat constants. vector.fma is
// excluded since its operands cannot be scalars. NOTE(review): several guard
// returns and loop headers are missing from this chunk.
1019struct ReorderElementwiseOpsOnBroadcast final
 1022 LogicalResult matchAndRewrite(Operation *op,
 1023 PatternRewriter &rewriter)
const override {
 1031 op,
"Op doesn't have ElementwiseMappableTraits");
 1034 if (isa<vector::FMAOp>(op)) {
 1037 "Op only accepts vector types - not supported as broadcast source "
 1038 "might be a scalar");
 1041 Type resultElemType = resultType.getElementType();
// Find one broadcast-like operand to anchor the common source type.
 1044 Value broadcastSource;
 1046 Operation *definingOp = operand.getDefiningOp();
 1049 if (definingOp->
hasTrait<OpTrait::ConstantLike>())
 1051 broadcastSource = getBroadcastLikeSource(operand);
 1054 if (!broadcastSource)
 1056 Type unbroadcastResultType =
 1057 cloneOrReplace(broadcastSource.
getType(), resultElemType);
// Every operand must be a broadcast from the same shape or a splat const.
 1063 if (!llvm::all_of(op->
getOperands(), [broadcastSource](Value val) {
 1064 if (auto source = getBroadcastLikeSource(val))
 1065 return haveSameShapeAndScaling(source.getType(),
 1066 broadcastSource.getType());
 1067 SplatElementsAttr splatConst;
 1068 return matchPattern(val, m_Constant(&splatConst));
 1072 "not all operands are constants or broadcasts from the same type");
// Materialize un-broadcast operands: splat consts are rebuilt at the
// smaller type; broadcasts contribute their source directly.
 1076 SmallVector<Value> srcValues;
 1079 SplatElementsAttr splatConst;
 1083 Type newType = cloneOrReplace(unbroadcastResultType, elementType);
 1084 if (
auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
 1089 Operation *newConstOp =
 1090 operand.getDefiningOp()->getDialect()->materializeConstant(
 1091 rewriter, newConst, newType, operand.getLoc());
 1092 srcValues.push_back(newConstOp->
getResult(0));
 1094 srcValues.push_back(operand.getDefiningOp()->getOperand(0));
 1099 Operation *elementwiseOp =
 1101 unbroadcastResultType, op->
getAttrs());
 1105 op, resultType, elementwiseOp->
getResults());
// Pattern fragment: sinks a vector.extract above an elementwise op —
// extract(elementwise(a, b)) becomes elementwise(extract(a), extract(b)) —
// by extracting each operand at the same static position and cloning the
// elementwise op on the scalars/sub-vectors. Dynamic positions and
// vector.fma are rejected. NOTE(review): guard returns and the final
// replacement lines are missing from this chunk.
1127class ExtractOpFromElementwise final
 1132 LogicalResult matchAndRewrite(vector::ExtractOp op,
 1133 PatternRewriter &rewriter)
const override {
 1134 Operation *eltwise = op.getSource().getDefiningOp();
 1139 isa<vector::FMAOp>(eltwise))
 1153 if (!op.getDynamicPosition().empty())
 1155 op,
"dynamic position not yet implemented");
 1157 Type dstType = op.getType();
 1159 OpBuilder::InsertionGuard g(rewriter);
 1163 Location loc = eltwise->
getLoc();
 1164 SmallVector<OpFoldResult> pos = op.getMixedPosition();
// Extract each operand at the same position and remap it for the clone.
 1166 Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
 1167 mapping.
map(arg, newArg);
 1170 Operation *newEltwise = rewriter.
clone(*eltwise, mapping);
// Fragment: predicate for element types supported by the memory-sink
// patterns below; index types are handled by the visible branch.
// NOTE(review): the remaining checks and returns (orig. 1183+) are missing.
1181static bool isSupportedMemSinkElementType(
Type type) {
 1182 if (isa<IndexType>(type))
// Pattern fragment: folds extract(vector.load) into a scalar/smaller load
// by adjusting the load indices with the extract position. Requires the
// load to have a single use, a non-scalable vector type, and a supported
// memref element type. NOTE(review): guard returns and the rewrite tail are
// missing from this chunk.
1203class ExtractOpFromLoad final :
public OpRewritePattern<vector::ExtractOp> {
 1208 PatternRewriter &rewriter)
const override {
 1209 auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
// Only safe to shrink the load if nobody else consumes the full vector.
 1214 if (!loadOp->hasOneUse())
 1217 VectorType loadVecType = loadOp.getVectorType();
 1218 if (loadVecType.isScalable())
 1220 "scalable vectors are not supported");
 1222 MemRefType memType = loadOp.getMemRefType();
 1226 if (!isSupportedMemSinkElementType(memType.getElementType()))
 1229 int64_t rankOffset = memType.getRank() - loadVecType.getRank();
 1233 auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
 1234 int64_t finalRank = 0;
 1236 finalRank = extractVecType.getRank();
 1238 SmallVector<Value>
indices = loadOp.getIndices();
 1239 SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
 1244 OpBuilder::InsertionGuard g(rewriter);
 1246 Location loc = loadOp.getLoc();
 1247 ArithIndexingBuilder idxBuilderf(rewriter, loc);
// Fold each static/dynamic extract position into the load indices.
 1248 for (
auto i : llvm::seq<int64_t>(rankOffset,
indices.size() - finalRank)) {
 1249 OpFoldResult pos = extractPos[i - rankOffset];
 1257 Value base = loadOp.getBase();
 1258 if (extractVecType) {
// Pattern fragment: rewrites a vector.store of a 1-element broadcast into a
// scalar (or smaller) store of the broadcast source. Scalable vectors,
// memrefs of vectors, and multi-element vectors are rejected.
// NOTE(review): the rewrite tail after the source-type branch is missing.
1281class StoreOpFromBroadcast final :
public OpRewritePattern<vector::StoreOp> {
 1286 PatternRewriter &rewriter)
const override {
 1287 VectorType vecType = op.getVectorType();
 1288 if (vecType.isScalable())
 1290 "scalable vectors are not supported");
 1292 if (isa<VectorType>(op.getMemRefType().getElementType()))
 1294 op,
"memrefs of vectors are not supported");
 1296 if (vecType.getNumElements() != 1)
 1298 op,
"only 1-element vectors are supported");
 1300 Value toStore = op.getValueToStore();
 1301 Value source = getBroadcastLikeSource(toStore);
 1304 op,
"value to store is not from a broadcast");
 1311 Value base = op.getBase();
 1314 if (isa<VectorType>(source.
getType())) {
// Fragment: builds an [0..dim) index vector (i32 or i64 depending on
// force32BitVectorIndices), broadcasts the offset and bound, and returns the
// slt comparison used as a mask. NOTE(review): the function signature line
// and several statements between the branches are missing from this chunk.
 1334 bool force32BitVectorIndices,
int64_t dim,
 1343 if (dim == 0 && force32BitVectorIndices) {
 1346 }
else if (dim == 0) {
 1349 }
else if (force32BitVectorIndices) {
 1351 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
 1354 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
 1356 Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
// Broadcast the scalar offset across the index vector, add, then compare
// against the broadcast bound: mask[i] = (i + off < bound).
 1360 Value ov = vector::BroadcastOp::create(rewriter, loc,
indices.getType(), o);
 1366 vector::BroadcastOp::create(rewriter, loc,
indices.getType(), bound);
 1367 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
// Pattern fragment (templated over the transfer op type): materializes an
// explicit vector.create_mask for a rank-1 out-of-bounds transfer op, ANDs
// it with any existing mask, and marks the transfer in-bounds.
// NOTE(review): several statements (e.g. the dim computation feeding `dim`)
// are missing from this chunk.
1371template <
typename ConcreteOp>
 1374 explicit MaterializeTransferMask(MLIRContext *context,
bool enableIndexOpt,
 1375 PatternBenefit benefit = 1)
 1376 : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
 1377 force32BitVectorIndices(enableIndexOpt) {}
 1380 PatternRewriter &rewriter)
const override {
// Nothing to do for fully in-bounds transfers.
 1381 if (!xferOp.hasOutOfBoundsDim())
 1384 if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
 1387 Location loc = xferOp->getLoc();
 1388 VectorType vtp = xferOp.getVectorType();
// Mask bound is (dim size - last index): how many elements remain.
 1395 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
 1396 Value off = xferOp.getIndices()[lastIndex];
 1399 Value
b = arith::SubIOp::create(rewriter, loc, dim.
getType(), dim, off);
 1400 Value mask = vector::CreateMaskOp::create(
 1402 VectorType::get(vtp.getShape(), rewriter.
getI1Type(),
 1403 vtp.getScalableDims()),
// Combine with any pre-existing mask so both constraints apply.
 1405 if (xferOp.getMask()) {
 1407 mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
 1411 xferOp.getMaskMutable().assign(mask);
 1419 const bool force32BitVectorIndices;
// Pattern fragment: lowers a non-scalable vector.create_mask into the
// index-vector comparison built by buildVectorComparison.
// NOTE(review): the base-class clause and the final replace call's closing
// lines are missing from this chunk.
1423class VectorCreateMaskOpConversion
 1426 explicit VectorCreateMaskOpConversion(MLIRContext *context,
 1427 bool enableIndexOpt,
 1428 PatternBenefit benefit = 1)
 1429 : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
 1430 force32BitVectorIndices(enableIndexOpt) {}
 1432 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
 1433 PatternRewriter &rewriter)
const override {
 1434 auto dstType = op.getType();
// Scalable masks cannot be expressed as a constant index comparison.
 1435 if (cast<VectorType>(dstType).isScalable())
 1437 int64_t rank = dstType.getRank();
 1441 op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
 1442 rank == 0 ? 0 : dstType.getDimSize(0),
 1448 const bool force32BitVectorIndices;
// Helper: returns true when `constantOp` holds a splat i1 dense constant
// whose value equals `value`. NOTE(review): the early return for a null
// denseAttr (orig. 1454-1457) is missing from this chunk.
1452static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp,
bool value) {
 1453 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
 1458 assert(denseAttr.getElementType().isInteger(1) &&
"Unexpected type");
 1459 return denseAttr.isSplat() && denseAttr.getSplatValue<
bool>() == value;
// Pattern fragment: matches arith.select on a 1-element i1 vector with
// all-true / all-false constant branches (i.e. a broadcast of the scalar
// condition) and rewrites it via an integer bitcast of matching width.
// NOTE(review): the struct header, signature line, and rewrite tail are
// missing from this chunk; the pattern name cannot be confirmed.
 1477 PatternRewriter &rewriter)
const override {
 1478 auto vecType = dyn_cast<VectorType>(selectOp.getType());
 1479 if (!vecType || !vecType.getElementType().isInteger(1))
// The condition must be a scalar (vector conditions select per-lane).
 1483 Value cond = selectOp.getCondition();
 1484 if (isa<VectorType>(cond.
getType()))
 1488 if (vecType.getRank() != 1 || vecType.isScalable())
 1492 if (vecType.getShape()[0] != 1)
 1495 auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
 1496 if (!trueConst || !allI1ConstantValuesSetTo(trueConst,
true))
 1500 selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
 1501 if (!falseConst || !allI1ConstantValuesSetTo(falseConst,
false))
// Reinterpret the i1 vector as a single integer of the same bit count.
 1505 auto elemType = rewriter.
getIntegerType(vecType.getNumElements());
 1506 auto bcastType = VectorType::get({1}, elemType);
// Fragment: computes how many trailing unit dims (unit size, unit stride,
// non-scalable) are shared by the memref `srcType` and vector `vectorType`,
// so transfer ops can drop them. NOTE(review): the signature line and the
// counting/return logic after the loop condition are missing.
1527static FailureOr<size_t>
 1531 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
 1534 auto isUnitDim = [](VectorType type,
int dim) {
 1535 return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
// Walk dims from innermost outward; stop at the first non-foldable one.
 1542 int rankDiff = srcType.getRank() - vectorType.getRank();
 1543 for (
int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
 1546 int dim = vectorType.getRank() - i - 1;
 1547 if (srcStrides[dim + rankDiff] != 1 ||
 1548 srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
// Pattern fragment: drops trailing unit dims from a minor-identity
// vector.transfer_read by reading through a rank-reduced memref.subview and
// shape-casting the mask/result accordingly. Bails out when a dropped dim
// is out-of-bounds. NOTE(review): the struct header and several return
// statements are missing from this chunk.
 1560 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
 1563 if (readOp.getTransferRank() == 0)
 1566 auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
 1570 if (!readOp.getPermutationMap().isMinorIdentity())
 1573 auto targetType = readOp.getVectorType();
 1574 if (targetType.getRank() <= 1)
 1577 FailureOr<size_t> maybeDimsToDrop =
 1579 if (failed(maybeDimsToDrop))
 1582 size_t dimsToDrop = maybeDimsToDrop.value();
 1583 if (dimsToDrop == 0)
// Dropped dims must all be in-bounds, otherwise masking semantics change.
 1586 auto inBounds = readOp.getInBoundsValues();
 1587 auto droppedInBounds =
ArrayRef<bool>(inBounds).take_back(dimsToDrop);
 1588 if (llvm::is_contained(droppedInBounds,
false))
 1591 auto resultTargetVecType =
 1592 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
 1593 targetType.getElementType(),
 1594 targetType.getScalableDims().drop_back(dimsToDrop));
 1596 auto loc = readOp.getLoc();
// Read through a rank-reduced subview covering the same elements.
 1603 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
 1604 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
 1607 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
 1608 Value rankedReducedView =
 1609 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
 1610 readOp.getBase(), offsets, sizes, strides);
 1612 cast<ShapedType>(rankedReducedView.
getType()), resultTargetVecType);
// Shape-cast an existing mask to the reduced rank as well.
 1615 Value mask = readOp.getMask();
 1617 auto maskType = cast<VectorType>(mask.
getType());
 1618 auto reducedMaskType = VectorType::get(
 1619 maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
 1620 maskType.getScalableDims().drop_back(dimsToDrop));
 1621 mask = rewriter.
createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
 1626 rewriter, loc, resultTargetVecType, rankedReducedView,
 1627 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
 1628 readOp.getPadding(), mask, inBoundsAttr);
// Pattern fragment: mirror of the transfer_read pattern above for
// vector.transfer_write — drops trailing unit dims by writing a shape-cast
// vector through a rank-reduced memref.subview. NOTE(review): the struct
// header and several return statements are missing from this chunk.
 1657 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
 1660 if (writeOp.getTransferRank() == 0)
 1663 auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
 1667 if (!writeOp.getPermutationMap().isMinorIdentity())
 1670 auto targetType = writeOp.getVectorType();
 1671 if (targetType.getRank() <= 1)
 1674 FailureOr<size_t> maybeDimsToDrop =
 1676 if (failed(maybeDimsToDrop))
 1679 size_t dimsToDrop = maybeDimsToDrop.value();
 1680 if (dimsToDrop == 0)
// Dropped dims must all be in-bounds, otherwise masking semantics change.
 1683 auto inBounds = writeOp.getInBoundsValues();
 1684 auto droppedInBounds =
ArrayRef<bool>(inBounds).take_back(dimsToDrop);
 1685 if (llvm::is_contained(droppedInBounds,
false))
 1688 auto resultTargetVecType =
 1689 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
 1690 targetType.getElementType(),
 1691 targetType.getScalableDims().drop_back(dimsToDrop));
 1700 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
 1701 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
 1704 writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
 1706 Value rankedReducedView =
 1707 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
 1708 writeOp.getBase(), offsets, sizes, strides);
 1710 cast<ShapedType>(rankedReducedView.
getType()), resultTargetVecType);
// Shape-cast the stored vector down to the reduced rank.
 1712 auto shapeCast = rewriter.
createOrFold<vector::ShapeCastOp>(
 1713 loc, resultTargetVecType, writeOp.getVector());
 1716 Value mask = writeOp.getMask();
 1718 auto maskType = cast<VectorType>(mask.
getType());
 1719 auto reducedMaskType = VectorType::get(
 1720 maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
 1721 maskType.getScalableDims().drop_back(dimsToDrop));
 1722 mask = rewriter.
createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
 1727 writeOp, shapeCast, rankedReducedView,
 1728 writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
 1729 mask, inBoundsAttr);
// Pattern fragment: canonicalizes a 2-D vector.contraction to the MMT form
// (indexing maps {{m,k},{n,k},{m,n}}) by inserting vector.transpose ops on
// lhs/rhs as needed; transposes are hoisted above arith.extsi/extui so the
// extension happens after the (cheaper, narrower) transpose. Gated by a
// user-supplied filter. NOTE(review): the struct header and several lines
// (e.g. the AffineExpr bindings m/n/k) are missing from this chunk.
 1742 std::function<LogicalResult(vector::ContractionOp op)>;
 1747 filter(std::move(constraint)) {}
 1751 if (failed(filter(op)))
 1757 Value res = op.getAcc();
 1761 auto infer = [&](MapList m) {
 1768 static constexpr std::array<int64_t, 2> perm = {1, 0};
 1769 auto iteratorTypes = op.getIteratorTypes().getValue();
// Only plain matmul-shaped contractions (2 parallel + 1 reduction).
 1771 if (iteratorTypes.size() != 3 ||
 1778 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
 1779 if (maps == canonicalForm)
// Transpose helper that pushes the transpose through sign/zero extends.
 1784 auto createTranspose = [&rewriter, loc](
Value mat) ->
Value {
 1785 if (
auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
 1787 vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
 1788 VectorType newType =
 1789 cast<VectorType>(trans.
getType())
 1790 .clone(cast<VectorType>(mat.getType()).getElementType());
 1791 return arith::ExtSIOp::create(rewriter, loc, newType, trans);
 1793 if (
auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
 1795 vector::TransposeOp::create(rewriter, loc,
zext.getIn(), perm);
 1796 VectorType newType =
 1797 VectorType::get(cast<VectorType>(trans.
getType()).getShape(),
 1798 cast<VectorType>(mat.getType()).getElementType());
 1799 return arith::ExtUIOp::create(rewriter, loc, newType, trans);
 1801 return vector::TransposeOp::create(rewriter, loc, mat, perm);
// Dispatch on the six non-canonical map layouts; transpose as needed.
 1804 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
 1805 rhs = createTranspose(
rhs);
 1806 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
 1807 lhs = createTranspose(
lhs);
 1808 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
 1809 rhs = createTranspose(
rhs);
 1810 lhs = createTranspose(
lhs);
 1811 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
 1813 rhs = createTranspose(
rhs);
 1814 lhs = createTranspose(
lhs);
 1815 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
 1817 rhs = createTranspose(
rhs);
 1818 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
 1820 lhs = createTranspose(
lhs);
 1821 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
 1828 op.getIteratorTypes());
// Pattern fragment (templated over the ext op type): folds matching
// arith.ext* ops on both contraction operands into the contraction itself,
// relying on vector.contract's implicit operand extension semantics.
// NOTE(review): the struct header and signature lines are missing.
1853template <
typename ExtOp>
 1861 auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
 1862 auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
// Both sides must be extended the same way for the fold to be valid.
 1864 if (!lhsDefOp || !rhsDefOp) {
 1866 "no defining op on contract operands");
 1870 contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
 1871 contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
 1872 contractOp.getIteratorTypesAttr());
// Pattern fragment: folds chained vector.reduction<add> ops — when the
// accumulator of an add-reduction is itself an add-reduction, the two input
// vectors are added element-wise first and reduced once, reusing the parent's
// accumulator. NOTE(review): the struct header, signature, and the integer
// add line preceding orig. 1913 are missing from this chunk.
 1894 if (op.getKind() != vector::CombiningKind::ADD)
 1902 if (!
acc.getType().isIntOrFloat())
 1905 auto parentReduction =
acc.getDefiningOp<vector::ReductionOp>();
 1906 if (!parentReduction)
// Pick addi vs addf based on the element type.
 1911 if (isa<IntegerType>(
acc.getType())) {
 1913 loc, parentReduction.getVector(), op.getVector());
 1915 vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
 1919 parentReduction.getAcc());
// Fragment: returns a VectorType equal to `inVecTy` with all non-scalable
// unit dims removed; if everything was dropped, a single unit dim is kept so
// the result stays a valid vector type. NOTE(review): the signature line and
// the newShape/newScalableDims declarations are missing from this chunk.
 1930 auto inVecShape = inVecTy.getShape();
 1933 for (
auto [dim, isScalable] :
 1934 llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
// Drop only static unit dims; scalable unit dims are meaningful.
 1935 if (dim == 1 && !isScalable)
 1938 newShape.push_back(dim);
 1939 newScalableDims.push_back(isScalable);
// Never produce a 0-rank result; fall back to vector<1x...>.
 1942 if (newShape.empty()) {
 1943 newShape.push_back(1);
 1944 newScalableDims.push_back(
false);
 1947 return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1984 if (!resultVectorType)
1991 if (!sourceVectorType)
1993 if (sourceVectorType.getRank() < 2)
1999 auto opVectorType = cast<VectorType>(operand.getType());
2001 if (newVType == opVectorType)
2004 auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
2005 newOperands.push_back(opSC);
2008 VectorType newResultVectorType =
2013 newResultVectorType, op->
getAttrs());
2048 VectorType sourceType = op.getSourceVectorType();
2049 VectorType sourceTypeWithoutUnitDims =
2052 if (sourceType == sourceTypeWithoutUnitDims)
2059 for (
auto [i, dim] : llvm::enumerate(sourceDims)) {
2060 droppedDimsBefore[i] = droppedDims;
2061 if (dim == std::make_tuple(1,
false))
2069 if (sourceDims[idx] == std::make_tuple(1,
false))
2071 newPerm.push_back(idx - droppedDimsBefore[idx]);
2077 if (newPerm.empty()) {
2078 newPerm.push_back(0);
2083 auto dropDimsShapeCast = vector::ShapeCastOp::create(
2084 rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
2086 auto transposeWithoutUnitDims =
2087 vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
2090 op, op.getResultVectorType(), transposeWithoutUnitDims);
2127 for (
OpOperand &operand : forOp.getInitArgsMutable()) {
2128 auto vectorType = dyn_cast<VectorType>(operand.get().getType());
2133 if (vectorType == newVectorType)
2138 return vector::ShapeCastOp::create(
b, loc, type, source);
2142 castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
2144 replaceAndCastForOpIterArg(rewriter, forOp, operand,
2171 if (op.getKind() != vector::CombiningKind::ADD)
2174 Type elemType = op.getSourceVectorType().getElementType();
2177 if (!isa<FloatType>(elemType))
2180 auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
2183 auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
2190 auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
2191 addLhs.getLhs(), vAdd.getRhs());
2210 unsigned maxNumElementsToExtract,
2213 maxNumElementsToExtract(maxNumElementsToExtract) {}
2217 VectorType type = op.getSourceVectorType();
2218 if (type.isScalable() || op.isMasked())
2220 assert(type.getRank() == 1 &&
"Expected a 1-d vector");
2222 int64_t numElems = type.getNumElements();
2223 if (numElems > maxNumElementsToExtract) {
2225 op, llvm::formatv(
"has too many vector elements ({0}) to break down "
2226 "(max allowed: {1})",
2227 numElems, maxNumElementsToExtract));
2232 for (
auto [idx, extractedElem] : llvm::enumerate(extracted))
2233 extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
2236 Value res = extracted.front();
2237 for (
auto extractedElem : llvm::drop_begin(extracted))
2239 extractedElem, op.getFastmathAttr());
2242 op.getFastmathAttr());
2249 unsigned maxNumElementsToExtract = 0;
2268template <
typename MulOpType>
2273 bool isValidBroadcastSource(vector::BroadcastOp broadcastOp)
const {
2276 if (!broadcastOp.computeBroadcastedUnitDims().empty())
2279 auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
2280 return srcType && srcType.getRank() != 2;
2285 auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
2288 if (resType.getRank() != 2)
2293 auto matchOuterProduct =
2295 Value operandB) -> FailureOr<vector::OuterProductOp> {
2296 auto transposedLhs = operandA.
getDefiningOp<vector::TransposeOp>();
2301 if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2304 auto broadcastedLhs =
2305 transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2306 if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2309 auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2310 if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2313 return vector::OuterProductOp::create(
2314 rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2315 broadcastedRhs.getSource(),
Value(), vector::CombiningKind::ADD);
2318 Value lhs = mulOp->getOperand(0),
rhs = mulOp->getOperand(1);
2319 auto maybeOuterP = matchOuterProduct(
lhs,
rhs);
2321 if (failed(maybeOuterP))
2322 maybeOuterP = matchOuterProduct(
rhs,
lhs);
2323 if (failed(maybeOuterP))
2325 rewriter.
replaceOp(mulOp, maybeOuterP->getResult());
2339void mlir::vector::populateVectorMaskMaterializationPatterns(
2342 patterns.
add<VectorCreateMaskOpConversion,
2343 MaterializeTransferMask<vector::TransferReadOp>,
2344 MaterializeTransferMask<vector::TransferWriteOp>>(
2345 patterns.
getContext(), force32BitVectorIndices, benefit);
2349void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2355void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2357 patterns.
add<BubbleDownVectorBitCastForExtract,
2358 BubbleDownBitCastForStridedSliceExtract,
2359 BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2363void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2365 std::function<
bool(vector::BitCastOp)> controlFn,
PatternBenefit benefit) {
2367 std::move(controlFn), benefit);
2372 std::function<LogicalResult(vector::ContractionOp)> constraint,
2375 std::move(constraint));
2380 patterns.
add<MultiReduceToContract, CombineContractBroadcastMask,
2381 CombineContractABTranspose, CombineContractResultTranspose>(
2394 patterns.
add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2395 ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2402 patterns.
add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.
getContext(),
2406void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2413void mlir::vector::populateBreakDownVectorReductionPatterns(
2417 maxNumElementsToExtract, benefit);
2422 patterns.
add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2423 FoldArithToVectorOuterProduct<arith::MulIOp>>(
2431#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
static uint64_t zext(uint32_t arg)
Returns failure if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be placed. The output argument `nBegin` is set to the replacement of `begin` (set to `begin` itself if no invalidation happens), since outgoing copies could have been inserted at `end`.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Drop innermost contiguous unit dimensions from transfer_read operand.
Drop innermost contiguous unit dimensions from transfer_write operand. E.g., vector....
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
IntegerAttr getIndexAttr(int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold elementwise op on vectors to the vector dialect.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the most inner unit dims in xfer Ops.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction with M...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
BreakDownVectorReduction(MLIRContext *context, unsigned maxNumElementsToExtract, PatternBenefit benefit)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction suitab...
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit, FilterConstraintType constraint)
Pattern to fold chained reduction to a series of vector additions and a final reduction....
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
For vectors with at least one unit dim, replaces: elementwise(a, b) with: sc_a = shape_cast(a) sc_b =...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
A pattern to drop unit dims from the iter_args of an scf.for.
LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override
A pattern to drop unit dims from vector.transpose.
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
Pattern to fold arithmetic extensions on floating point data types into vector contraction operations...
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override
Pattern to eliminate redundant zero-constants added to reduction operands. It's enough for there to b...
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.