30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32#include "llvm/Support/FormatVariadic.h"
39#define DEBUG_TYPE "vector-to-vector"
// Helper template: converts an ArrayAttr of IntegerAttr into a
// SmallVector<IntType> (capacity hint 4). NOTE(review): the function
// signature line is elided in this excerpt — confirm name/params upstream.
44template <
typename IntType>
46 return llvm::to_vector<4>(llvm::map_range(
47 arrayAttr.getAsRange<IntegerAttr>(),
48 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
// Pattern: rewrites vector.multi_reduction<add> whose source is an
// arith.muli/arith.mulf into a vector.contraction. Dims flagged in the
// reduction mask become reduction iterators; the others stay parallel.
// NOTE(review): interior lines (early `return failure()`s, map assembly,
// the final rewrite call) are elided in this excerpt.
80struct MultiReduceToContract
84 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
85 PatternRewriter &rewriter)
const override {
// Only additive reductions over a multiply match the contraction form.
86 if (reduceOp.getKind() != vector::CombiningKind::ADD)
88 Operation *mulOp = reduceOp.getSource().getDefiningOp();
89 if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
91 SmallVector<bool> reductionMask = reduceOp.getReductionMask();
93 SmallVector<AffineExpr> exprs;
94 SmallVector<vector::IteratorType> iteratorTypes;
// Build one iterator type per source dim from the reduction mask.
95 for (
const auto &isReduceDim : llvm::enumerate(reductionMask)) {
96 if (!isReduceDim.value()) {
97 iteratorTypes.push_back(vector::IteratorType::parallel);
100 iteratorTypes.push_back(vector::IteratorType::reduction);
105 0, exprs, reduceOp.getContext());
110 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
111 return IteratorTypeAttr::get(rewriter.getContext(), t);
// Pattern: folds vector.transpose ops feeding a vector.contraction's LHS/RHS
// into the contraction's indexing maps, replacing each transposed operand by
// the transpose's source. NOTE(review): the `index` initialization, the
// transpose-missing early-continue, and the final replace call are elided.
140struct CombineContractABTranspose final
144 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
145 PatternRewriter &rewriter)
const override {
146 SmallVector<AffineMap> maps =
147 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
148 Value
lhs = contractOp.getLhs();
149 Value
rhs = contractOp.getRhs();
151 bool changed =
false;
// Visit lhs and rhs uniformly; `map` is the corresponding indexing map.
152 for (Value *operand : {&
lhs, &
rhs}) {
153 AffineMap &map = maps[index++];
154 auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
158 transposeOp.getPermutation(), contractOp.getContext());
// Replace the operand by the pre-transpose vector.
160 *operand = transposeOp.getVector();
166 contractOp,
lhs,
rhs, contractOp.getAcc(),
// Pattern: when a vector.transpose consumes the single use of a
// vector.contraction whose accumulator also comes from a transpose,
// folds both transposes into the contraction's result map.
// NOTE(review): guards and the permutation-map construction are elided.
204struct CombineContractResultTranspose final
208 LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
209 PatternRewriter &rewriter)
const override {
210 auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
// The contraction must be single-use so rewiring its result is safe.
211 if (!contractOp || !contractOp->hasOneUse())
214 auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
218 MLIRContext *context = contractOp.getContext();
219 auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
220 AffineMap contractMap = maps.back();
// Compose the result transpose into the contraction's accumulator map.
231 auto combinedResMap = resTMap.compose(contractMap);
238 maps.back() = combinedResMap;
241 resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
// Attempts to fold vector.broadcast ops feeding a (possibly masked)
// vector.contraction into the contraction's indexing maps, dropping the
// broadcast dims. Returns the new contraction value, or failure.
// Bails out on inner-dim broadcasts, on broadcasts of non-unit reduction
// dims, and (for masked contractions) when a dropped mask dim is non-unit.
// NOTE(review): many interior lines (operand loop header, early returns,
// final mask application) are elided in this excerpt.
274FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
275 MaskingOpInterface maskingOp,
278 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
282 bool changed =
false;
285 auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
289 auto srcType = dyn_cast<VectorType>(
broadcast.getSourceType());
// Broadcast from a same-rank vector adds no leading dims to fold.
291 srcType.getRank() ==
broadcast.getResultVectorType().getRank())
294 broadcast.getResultVectorType().getRank() - srcType.getRank();
295 bool innerDimBroadcast =
false;
// Detect a broadcast that stretches an existing (inner) dim rather than
// only prepending new leading dims — that case cannot be folded here.
297 for (
const auto &dim : llvm::enumerate(srcType.getShape())) {
299 broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
300 innerDimBroadcast =
true;
307 if (innerDimBroadcast)
312 bool nonUnitDimReductionBroadcast =
false;
// Broadcast leading dims that are reduced must be unit-sized, otherwise
// dropping them would change the reduction result.
313 for (
int64_t i = 0; i < rankDiff; ++i) {
314 if (
broadcast.getResultVectorType().getDimSize(i) != 1 &&
317 nonUnitDimReductionBroadcast =
true;
321 if (nonUnitDimReductionBroadcast)
326 originalDims, contractOp.getContext());
327 map = broadcastMap.
compose(map);
// Keep only iterator types for dims still used after folding.
343 for (
unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
344 if (!unusedDimsBitVector.test(i))
345 iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
351 VectorType oldMaskType;
352 bool isAnyUnusedDimNonUnit =
false;
354 oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
355 for (
unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
356 if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
357 isAnyUnusedDimNonUnit =
true;
368 bool hasReductionIteratorApplyingOnBothSides =
false;
369 for (
unsigned i = 0; i < iterators.size(); ++i) {
373 hasReductionIteratorApplyingOnBothSides =
true;
// A contraction needs at least one reduction dim used by both operands.
377 if (!hasReductionIteratorApplyingOnBothSides)
385 Operation *newOp = vector::ContractionOp::create(
386 rewriter, contractOp.getLoc(),
lhs,
rhs, contractOp.getAcc(),
// Dropping a non-unit mask dim would change masking semantics.
391 if (isAnyUnusedDimNonUnit)
393 "Cannont drop non-unit mask dim.");
394 assert(unusedDimsBitVector.size() ==
395 static_cast<size_t>(oldMaskType.getRank()) &&
396 "The mask rank is incorrect!");
400 Value mask = maskingOp.getMask();
// Shape-cast the mask to drop the same leading unit dims.
401 if (unusedDimsBitVector.count() != 0) {
409 oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
410 auto newShapeScalableDims =
411 oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
412 VectorType maskOpType =
413 VectorType::get(newShape, rewriter.
getI1Type(), newShapeScalableDims);
414 mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
415 maskOpType, maskingOp.getMask())
// Maskable pattern wrapper: delegates masked/unmasked vector.contraction
// broadcast-folding to combineContractAndBroadcast above.
424struct CombineContractBroadcastMask
426 using MaskableOpRewritePattern::MaskableOpRewritePattern;
429 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
430 MaskingOpInterface maskingOp,
431 PatternRewriter &rewriter)
const override {
432 return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
// Pattern: reorders cast(broadcast(x)) into broadcast(cast(x)) for any
// single-operand CastOpInterface op producing a vector, so the cast runs on
// the smaller pre-broadcast value. NOTE(review): early failure returns and
// the replacement op creation are partially elided in this excerpt.
449struct ReorderCastOpsOnBroadcast
451 using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
453 LogicalResult matchAndRewrite(CastOpInterface op,
454 PatternRewriter &rewriter)
const override {
455 if (op->getNumOperands() != 1)
457 if (!isa<VectorType>(op->getResult(0).getType()))
459 auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
// If the broadcast source is itself a vector, keep its shape for the cast.
464 if (
auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
465 castResTy = vecTy.clone(castResTy);
// Recreate the cast generically by op name, now on the broadcast source.
467 rewriter.
create(op->getLoc(), op->getName().getIdentifier(),
468 bcastOp.getSource(), castResTy, op->getAttrs());
470 op, op->getResult(0).
getType(), castOp->getResult(0));
// Pattern: hoists identical vector.transpose ops past an elementwise op —
// rewrites elementwise(transpose(a), transpose(b)) as
// transpose(elementwise(a, b)). Operands that are not transposed are
// transposed with the inverse permutation first so all inputs agree.
// NOTE(review): operand loop headers and the final replacement are elided.
489struct ReorderElementwiseOpsOnTranspose final
492 LogicalResult matchAndRewrite(Operation *op,
493 PatternRewriter &rewriter)
const override {
499 SmallVector<ArrayRef<int64_t>> transposeMaps;
505 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
507 transposeMaps.push_back(transposeOp.getPermutation());
508 srcType = transposeOp.getSourceVectorType();
513 if (transposeMaps.empty())
// All transposed operands must share one permutation to commute safely.
518 if (!llvm::all_equal(transposeMaps))
521 SmallVector<Value> srcValues;
// Compute the inverse permutation for the non-transposed operands.
526 auto order = transposeMaps.front();
527 SmallVector<int64_t> invOrder(order.size());
528 for (
int i = 0, e = order.size(); i < e; ++i)
529 invOrder[order[i]] = i;
532 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
534 srcValues.push_back(transposeOp.getVector());
538 srcType.clone(cast<VectorType>(operand.getType()).getElementType());
539 srcValues.push_back(vector::TransposeOp::create(
540 rewriter, operand.getLoc(), vectorType, operand, invOrder));
544 auto vectorType = srcType.clone(
546 Operation *elementwiseOp =
551 transposeMaps.front());
// Converts an ArrayAttr of IntegerAttr to a vector of int64_t values.
// NOTE(review): the signature line is elided in this excerpt.
558 return llvm::map_to_vector<4>(arrayAttr.getAsRange<IntegerAttr>(),
559 [](IntegerAttr attr) { return attr.getInt(); });
// Pattern: moves a scalar vector.extract above a shape-expanding
// vector.bitcast on rank-1 vectors: extract one pre-cast element, bitcast
// it to `expandRatio` narrow elements, then extract the wanted one.
// Requires a static extraction position. NOTE(review): several early
// failure returns and the final replacement are elided in this excerpt.
571struct BubbleDownVectorBitCastForExtract
575 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
576 PatternRewriter &rewriter)
const override {
// Only rank-1 sources are handled here.
578 if (extractOp.getSourceVectorType().getRank() != 1)
581 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
585 VectorType castSrcType = castOp.getSourceVectorType();
586 VectorType castDstType = castOp.getResultVectorType();
587 assert(castSrcType.getRank() == castDstType.getRank());
592 if (castSrcType.getNumElements() == 1)
// Only handle casts that expand the element count (wide -> narrow elems).
597 if (castSrcType.getNumElements() > castDstType.getNumElements())
600 unsigned expandRatio =
601 castDstType.getNumElements() / castSrcType.getNumElements();
// Dynamic positions cannot be folded statically.
604 auto mixedPos = extractOp.getMixedPosition();
605 if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
607 uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
611 Location loc = extractOp.getLoc();
// Extract the wide element containing the requested narrow element.
612 Value packedValue = vector::ExtractOp::create(
613 rewriter, loc, castOp.getSource(), index / expandRatio);
614 Type packedVecType = VectorType::get({1}, packedValue.
getType());
615 Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
617 packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
// Bitcast the 1-element wide vector into `expandRatio` narrow elements.
622 VectorType packedType =
623 VectorType::get({expandRatio}, castDstType.getElementType());
625 vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
629 index % expandRatio);
// Pattern: moves vector.extract_strided_slice above a shape-expanding
// vector.bitcast by scaling the slice's last-dim offset and size by the
// expand ratio. Requires unit strides and offsets/sizes divisible by the
// ratio. NOTE(review): early failure returns are elided in this excerpt.
646struct BubbleDownBitCastForStridedSliceExtract
650 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
651 PatternRewriter &rewriter)
const override {
652 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
656 VectorType castSrcType = castOp.getSourceVectorType();
657 VectorType castDstType = castOp.getResultVectorType();
658 assert(castSrcType.getRank() == castDstType.getRank());
660 int64_t castSrcLastDim = castSrcType.getShape().back();
661 int64_t castDstLastDim = castDstType.getShape().back();
// Only handle casts that expand the innermost dimension.
663 if (castSrcLastDim > castDstLastDim)
// Non-unit strides cannot be rescaled through the bitcast.
667 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
668 [](
const APInt &val) { return !val.isOne(); }))
671 unsigned rank = extractOp.getSourceVectorType().getRank();
672 assert(castDstLastDim % castSrcLastDim == 0);
673 int64_t expandRatio = castDstLastDim / castSrcLastDim;
// Rescale the last-dim offset; it must align to the expand ratio.
679 ArrayAttr newOffsets = extractOp.getOffsets();
680 if (newOffsets.size() == rank) {
681 SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
682 if (offsets.back() % expandRatio != 0)
684 offsets.back() = offsets.back() / expandRatio;
// Rescale the last-dim size under the same divisibility constraint.
689 ArrayAttr newSizes = extractOp.getSizes();
690 if (newSizes.size() == rank) {
691 SmallVector<int64_t> sizes = getIntValueVector(newSizes);
692 if (sizes.back() % expandRatio != 0)
694 sizes.back() = sizes.back() / expandRatio;
698 SmallVector<int64_t> dims =
699 llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
700 dims.back() = dims.back() / expandRatio;
701 VectorType newExtractType =
702 VectorType::get(dims, castSrcType.getElementType());
704 auto newExtractOp = vector::ExtractStridedSliceOp::create(
705 rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
706 newOffsets, newSizes, extractOp.getStrides());
709 extractOp, extractOp.getType(), newExtractOp);
// Pattern: moves a vector.bitcast above a vector.insert by bitcasting the
// inserted value and the destination separately, scaling their last dims
// by the element-count ratio (shrink or expand). Scalable vectors and
// rank-0 sources are rejected. NOTE(review): some guards and the ratio
// declaration are elided in this excerpt.
725struct BubbleUpBitCastForInsert :
public OpRewritePattern<vector::BitCastOp> {
728 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
729 PatternRewriter &rewriter)
const override {
730 VectorType castSrcType = bitcastOp.getSourceVectorType();
731 VectorType castDstType = bitcastOp.getResultVectorType();
734 if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
735 castDstType.isScalable())
738 int64_t castSrcLastDim = castSrcType.getShape().back();
739 int64_t castDstLastDim = castDstType.getShape().back();
// Determine whether the cast shrinks or grows the element count.
740 bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
742 if (isNumElemsShrink) {
743 assert(castSrcLastDim % castDstLastDim == 0);
744 ratio = castSrcLastDim / castDstLastDim;
746 assert(castDstLastDim % castSrcLastDim == 0);
747 ratio = castDstLastDim / castSrcLastDim;
750 auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
755 auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
// Bitcast the inserted value with its last dim rescaled.
760 SmallVector<int64_t> srcDims(insertSrcType.getShape());
762 isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
763 VectorType newCastSrcType =
764 VectorType::get(srcDims, castDstType.getElementType());
766 vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
767 insertOp.getValueToStore());
// Bitcast the insert destination with the same rescaling.
769 SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
771 isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
772 VectorType newCastDstType =
773 VectorType::get(dstDims, castDstType.getElementType());
776 auto newCastDstOp = vector::BitCastOp::create(
777 rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
781 bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
// Pattern: moves a shrinking vector.bitcast above a
// vector.insert_strided_slice by bitcasting source and destination and
// dividing last-dim offsets/shapes by the shrink ratio. Requires unit
// strides, matching ranks, and divisible offsets/element counts.
// NOTE(review): several early failure returns are elided in this excerpt.
797struct BubbleUpBitCastForStridedSliceInsert
801 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
802 PatternRewriter &rewriter)
const override {
803 VectorType castSrcType = bitcastOp.getSourceVectorType();
804 VectorType castDstType = bitcastOp.getResultVectorType();
805 assert(castSrcType.getRank() == castDstType.getRank());
807 if (castSrcType.getRank() == 0)
810 int64_t castSrcLastDim = castSrcType.getShape().back();
811 int64_t castDstLastDim = castDstType.getShape().back();
// Only shrinking casts (narrow -> wide elements) are handled here.
813 if (castSrcLastDim < castDstLastDim)
816 assert(castSrcLastDim % castDstLastDim == 0);
817 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
820 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
825 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
826 [](
const APInt &val) { return !val.isOne(); }))
829 unsigned rank = insertOp.getSourceVectorType().getRank();
// Rank-reducing inserts are not supported by this rewrite.
832 if (rank != insertOp.getDestVectorType().getRank())
836 unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
837 unsigned destinationWidth =
838 castDstType.getElementType().getIntOrFloatBitWidth();
839 unsigned numElements = destinationWidth / sourceWidth;
840 if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
843 ArrayAttr newOffsets = insertOp.getOffsets();
844 assert(newOffsets.size() == rank);
845 SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
// The innermost offset must align to the shrink ratio to stay exact.
846 if (offsets.back() % shrinkRatio != 0)
848 offsets.back() = offsets.back() / shrinkRatio;
851 SmallVector<int64_t> srcDims =
852 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
853 srcDims.back() = srcDims.back() / shrinkRatio;
854 VectorType newCastSrcType =
855 VectorType::get(srcDims, castDstType.getElementType());
858 vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
859 insertOp.getValueToStore());
861 SmallVector<int64_t> dstDims =
862 llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
863 dstDims.back() = dstDims.back() / shrinkRatio;
864 VectorType newCastDstType =
865 VectorType::get(dstDims, castDstType.getElementType());
867 auto newCastDstOp = vector::BitCastOp::create(
868 rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
871 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
872 insertOp.getStrides());
// Pattern: breaks a rank-1, fixed-size, shrinking vector.bitcast into
// `shrinkRatio` smaller extract-slice / bitcast / insert-slice steps,
// accumulating into a zero-broadcast result. An optional controlFn lets
// callers filter which bitcasts to break down. NOTE(review): the struct
// header, some guards, and the final replacement are elided in this excerpt.
900 BreakDownVectorBitCast(MLIRContext *context,
901 std::function<
bool(vector::BitCastOp)> controlFn,
902 PatternBenefit benefit)
903 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
905 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
906 PatternRewriter &rewriter)
const override {
// Respect the caller-provided filter, if any.
908 if (controlFn && !controlFn(bitcastOp))
911 VectorType castSrcType = bitcastOp.getSourceVectorType();
912 VectorType castDstType = bitcastOp.getResultVectorType();
913 assert(castSrcType.getRank() == castDstType.getRank());
918 if (castSrcType.isScalable())
920 "Scalable vectors are not supported");
923 if (castSrcType.getRank() != 1)
926 int64_t castSrcLastDim = castSrcType.getShape().back();
927 int64_t castDstLastDim = castDstType.getShape().back();
// Only shrinking casts are broken down by this pattern.
929 if (castSrcLastDim < castDstLastDim)
932 assert(castSrcLastDim % castDstLastDim == 0);
933 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
935 if (castSrcLastDim == shrinkRatio)
938 Location loc = bitcastOp.getLoc();
939 Type elemType = castDstType.getElementType();
// Seed the result with a broadcast of zero, then fill slice by slice.
942 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
944 Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
946 SmallVector<int64_t> sliceShape = {castDstLastDim};
947 SmallVector<int64_t> strides = {1};
948 VectorType newCastDstType =
949 VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
950 castDstType.getElementType());
952 for (
int i = 0, e = shrinkRatio; i < e; ++i) {
953 Value extracted = ExtractStridedSliceOp::create(
954 rewriter, loc, bitcastOp.getSource(),
955 ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
957 BitCastOp::create(rewriter, loc, newCastDstType, extracted);
958 res = InsertStridedSliceOp::create(
959 rewriter, loc, bitcast, res,
960 ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
// Optional predicate deciding which bitcasts this pattern applies to.
967 std::function<bool(BitCastOp)> controlFn;
// Returns true when both types are vectors with identical shapes and
// identical scalable-dims flags. NOTE(review): the non-vector fallthrough
// between the dyn_casts and the return is elided in this excerpt.
970static bool haveSameShapeAndScaling(
Type t,
Type u) {
971 auto tVec = dyn_cast<VectorType>(t);
972 auto uVec = dyn_cast<VectorType>(u);
979 return tVec.getShape() == uVec.getShape() &&
980 tVec.getScalableDims() == uVec.getScalableDims();
// If `type` is shaped, returns a clone of it with `newElementType` as the
// element type; otherwise returns `newElementType` itself.
985static Type cloneOrReplace(
Type type,
Type newElementType) {
986 if (
auto shapedType = dyn_cast<ShapedType>(type)) {
987 return shapedType.clone(newElementType);
989 return newElementType;
// Returns the source of a broadcast-like producer of `value` (at least
// vector.broadcast is handled). NOTE(review): the op lookup and any other
// handled cases/null return are elided in this excerpt — confirm upstream.
994static Value getBroadcastLikeSource(
Value value) {
1000 if (
auto broadcast = dyn_cast<vector::BroadcastOp>(op))
// Pattern: sinks broadcasts below an elementwise op — rewrites
// elementwise(broadcast(a), broadcast(b)) as broadcast(elementwise(a, b)).
// All operands must be broadcasts from the same shape or splat constants
// (splats are re-materialized at the unbroadcast type). vector.fma is
// rejected since its operands must be vectors. NOTE(review): loop headers
// and some early returns are elided in this excerpt.
1019struct ReorderElementwiseOpsOnBroadcast final
1022 LogicalResult matchAndRewrite(Operation *op,
1023 PatternRewriter &rewriter)
const override {
1031 op,
"Op doesn't have ElementwiseMappableTraits");
1034 if (isa<vector::FMAOp>(op)) {
1037 "Op only accepts vector types - not supported as broadcast source "
1038 "might be a scalar");
1041 Type resultElemType = resultType.getElementType();
// Find one broadcast-like operand to define the unbroadcast shape.
1044 Value broadcastSource;
1046 Operation *definingOp = operand.getDefiningOp();
// Constants are handled separately as splats below.
1049 if (definingOp->
hasTrait<OpTrait::ConstantLike>())
1051 broadcastSource = getBroadcastLikeSource(operand);
1054 if (!broadcastSource)
1056 Type unbroadcastResultType =
1057 cloneOrReplace(broadcastSource.
getType(), resultElemType);
// Every operand must be a broadcast from the same shape or a splat const.
1063 if (!llvm::all_of(op->
getOperands(), [broadcastSource](Value val) {
1064 if (auto source = getBroadcastLikeSource(val))
1065 return haveSameShapeAndScaling(source.getType(),
1066 broadcastSource.getType());
1067 SplatElementsAttr splatConst;
1068 return matchPattern(val, m_Constant(&splatConst));
1072 "not all operands are constants or broadcasts from the same type");
1076 SmallVector<Value> srcValues;
1079 SplatElementsAttr splatConst;
// Rebuild splat constants at the smaller, unbroadcast type.
1083 Type newType = cloneOrReplace(unbroadcastResultType, elementType);
1084 if (
auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
1089 Operation *newConstOp =
1090 operand.getDefiningOp()->getDialect()->materializeConstant(
1091 rewriter, newConst, newType, operand.getLoc());
1092 srcValues.push_back(newConstOp->
getResult(0));
1094 srcValues.push_back(operand.getDefiningOp()->getOperand(0));
// Recreate the elementwise op on the unbroadcast operands, then broadcast.
1099 Operation *elementwiseOp =
1101 unbroadcastResultType, op->
getAttrs());
1105 op, resultType, elementwiseOp->
getResults());
// Pattern: sinks a vector.extract above an elementwise producer by
// extracting each operand first and cloning the elementwise op on scalars
// (or smaller vectors) via an IR mapping. Dynamic positions and vector.fma
// are rejected. NOTE(review): guards, the mapping declaration, the operand
// loop, and the final replacement are elided in this excerpt.
1127class ExtractOpFromElementwise final
1132 LogicalResult matchAndRewrite(vector::ExtractOp op,
1133 PatternRewriter &rewriter)
const override {
1134 Operation *eltwise = op.getSource().getDefiningOp();
1139 isa<vector::FMAOp>(eltwise))
// Folding through a dynamic position is not implemented.
1153 if (!op.getDynamicPosition().empty())
1155 op,
"dynamic position not yet implemented");
1157 Type dstType = op.getType();
1159 OpBuilder::InsertionGuard g(rewriter);
1163 Location loc = eltwise->
getLoc();
1164 SmallVector<OpFoldResult> pos = op.getMixedPosition();
// Extract each operand at the same position and remap it for the clone.
1166 Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
1167 mapping.
map(arg, newArg);
1170 Operation *newEltwise = rewriter.
clone(*eltwise, mapping);
// Predicate used by the load/store sinking patterns below; index elements
// are explicitly handled. NOTE(review): the remaining checks and return
// are elided in this excerpt — confirm the full predicate upstream.
1181static bool isSupportedMemSinkElementType(
Type type) {
1182 if (isa<IndexType>(type))
// Pattern: folds vector.extract(vector.load) into a narrower load (or a
// scalar memref.load) by offsetting the load indices with the extraction
// position. Requires a single-use, non-scalable load over a supported
// element type. NOTE(review): guards, the index-adjustment body, and the
// final replacement are elided in this excerpt.
1203class ExtractOpFromLoad final :
public OpRewritePattern<vector::ExtractOp> {
1208 PatternRewriter &rewriter)
const override {
1209 auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
// The original load must die after the rewrite, so require one use.
1214 if (!loadOp->hasOneUse())
1217 VectorType loadVecType = loadOp.getVectorType();
1218 if (loadVecType.isScalable())
1220 "scalable vectors are not supported");
1222 MemRefType memType = loadOp.getMemRefType();
1226 if (!isSupportedMemSinkElementType(memType.getElementType()))
1229 int64_t rankOffset = memType.getRank() - loadVecType.getRank();
// finalRank is 0 when extracting a scalar, else the result vector rank.
1233 auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
1234 int64_t finalRank = 0;
1236 finalRank = extractVecType.getRank();
1238 SmallVector<Value>
indices = loadOp.getIndices();
1239 SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
1244 OpBuilder::InsertionGuard g(rewriter);
1246 Location loc = loadOp.getLoc();
1247 ArithIndexingBuilder idxBuilderf(rewriter, loc);
// Add each extraction position into the corresponding load index.
1248 for (
auto i : llvm::seq<int64_t>(rankOffset,
indices.size() - finalRank)) {
1249 OpFoldResult pos = extractPos[i - rankOffset];
1257 Value base = loadOp.getBase();
1258 if (extractVecType) {
// Pattern: rewrites a vector.store of a 1-element broadcast(-like) value
// into a store of the broadcast source (scalar or vector) directly.
// Scalable vectors, memrefs of vectors, and multi-element vectors are
// rejected. NOTE(review): the final replacement paths are elided.
1281class StoreOpFromBroadcast final :
public OpRewritePattern<vector::StoreOp> {
1286 PatternRewriter &rewriter)
const override {
1287 VectorType vecType = op.getVectorType();
1288 if (vecType.isScalable())
1290 "scalable vectors are not supported");
1292 if (isa<VectorType>(op.getMemRefType().getElementType()))
1294 op,
"memrefs of vectors are not supported");
// Only single-element stores can be replaced by the broadcast source.
1296 if (vecType.getNumElements() != 1)
1298 op,
"only 1-element vectors are supported");
1300 Value toStore = op.getValueToStore();
1301 Value source = getBroadcastLikeSource(toStore);
1304 op,
"value to store is not from a broadcast");
1311 Value base = op.getBase();
// Vector source vs. scalar source take different replacement ops below.
1314 if (isa<VectorType>(source.
getType())) {
// Builds a vector comparison `[0..dim) + offset < bound` used to
// materialize masks: creates an index vector (i32 when
// force32BitVectorIndices, else i64), optionally adds a broadcast offset,
// clamps in 32-bit mode, and compares slt against the broadcast bound.
// NOTE(review): the function header and several lines are elided here.
bool force32BitVectorIndices,
int64_t dim,
// dim == 0 cases produce a degenerate (empty) index set.
1343 if (dim == 0 && force32BitVectorIndices) {
1346 }
else if (dim == 0) {
1349 }
else if (force32BitVectorIndices) {
1351 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1354 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1356 Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
// Add the offset (broadcast to the index vector shape), if present.
1360 Value ov = vector::BroadcastOp::create(rewriter, loc,
indices.getType(), o);
// In 32-bit mode clamp the bound so the i32 compare cannot overflow.
1369 if (force32BitVectorIndices) {
1372 b = arith::MinSIOp::create(rewriter, loc,
b, maxBound);
1376 vector::BroadcastOp::create(rewriter, loc,
indices.getType(), bound);
1377 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
// Pattern template (over a transfer op type): materializes the
// out-of-bounds mask of a rank<=1 transfer as an explicit
// vector.create_mask on `dim - lastIndex`, AND-ed with any existing mask,
// then marks the transfer in-bounds. NOTE(review): struct header lines,
// the dim computation, and the in-bounds update are elided in this excerpt.
1381template <
typename ConcreteOp>
1384 explicit MaterializeTransferMask(MLIRContext *context,
bool enableIndexOpt,
1385 PatternBenefit benefit = 1)
1386 : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1387 force32BitVectorIndices(enableIndexOpt) {}
1390 PatternRewriter &rewriter)
const override {
// Nothing to do for fully in-bounds transfers.
1391 if (!xferOp.hasOutOfBoundsDim())
1394 if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1397 Location loc = xferOp->getLoc();
1398 VectorType vtp = xferOp.getVectorType();
// The mask length is the remaining extent past the innermost index.
1405 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1406 Value off = xferOp.getIndices()[lastIndex];
1409 Value
b = arith::SubIOp::create(rewriter, loc, dim.
getType(), dim, off);
1410 Value mask = vector::CreateMaskOp::create(
1412 VectorType::get(vtp.getShape(), rewriter.
getI1Type(),
1413 vtp.getScalableDims()),
// Combine with a pre-existing mask so both constraints apply.
1415 if (xferOp.getMask()) {
1417 mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
1421 xferOp.getMaskMutable().assign(mask);
// Whether mask index arithmetic is done in i32 instead of i64/index.
1429 const bool force32BitVectorIndices;
// Pattern: lowers a non-scalable vector.create_mask to the comparison-based
// form built by buildVectorComparison (rank 0 treated as size 0).
// NOTE(review): the rank-1 restriction check and the replace call tail are
// elided in this excerpt.
1433class VectorCreateMaskOpConversion
1436 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1437 bool enableIndexOpt,
1438 PatternBenefit benefit = 1)
1439 : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1440 force32BitVectorIndices(enableIndexOpt) {}
1442 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1443 PatternRewriter &rewriter)
const override {
1444 auto dstType = op.getType();
// Scalable masks cannot use a constant index-vector comparison.
1445 if (cast<VectorType>(dstType).isScalable())
1447 int64_t rank = dstType.getRank();
1451 op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1452 rank == 0 ? 0 : dstType.getDimSize(0),
// Whether mask index arithmetic is done in i32 instead of i64/index.
1458 const bool force32BitVectorIndices;
// Returns true iff `constantOp` holds a dense-int splat of i1 elements all
// equal to `value`. Asserts the element type really is i1.
1462static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp,
bool value) {
1463 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1468 assert(denseAttr.getElementType().isInteger(1) &&
"Unexpected type");
1469 return denseAttr.isSplat() && denseAttr.getSplatValue<
bool>() == value;
// Pattern (name/header elided): matches arith.select on a 1-element,
// rank-1, non-scalable i1 vector with all-true / all-false constant arms
// and a scalar condition, rewriting it toward a broadcast of the condition
// bitcast to an iN integer type. NOTE(review): the op header and the final
// replacement are elided — confirm the rewrite shape upstream.
1487 PatternRewriter &rewriter)
const override {
1488 auto vecType = dyn_cast<VectorType>(selectOp.getType());
1489 if (!vecType || !vecType.getElementType().isInteger(1))
// The condition must be a scalar, not a per-lane vector.
1493 Value cond = selectOp.getCondition();
1494 if (isa<VectorType>(cond.
getType()))
1498 if (vecType.getRank() != 1 || vecType.isScalable())
1502 if (vecType.getShape()[0] != 1)
// Arms must be constant splats: true-arm all ones, false-arm all zeros.
1505 auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1506 if (!trueConst || !allI1ConstantValuesSetTo(trueConst,
true))
1510 selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1511 if (!falseConst || !allI1ConstantValuesSetTo(falseConst,
false))
1515 auto elemType = rewriter.
getIntegerType(vecType.getNumElements());
1516 auto bcastType = VectorType::get({1}, elemType);
// Counts how many trailing dims are droppable from a transfer: the memref
// must have static strides, and each counted trailing dim must be unit and
// non-scalable in the vector with stride 1 / size 1 in the memref.
// NOTE(review): the function name, the counting loop tail, and the return
// are elided in this excerpt.
1537static FailureOr<size_t>
1541 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
1544 auto isUnitDim = [](VectorType type,
int dim) {
1545 return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
// Walk dims innermost-first, stopping at the first non-foldable one.
1552 int rankDiff = srcType.getRank() - vectorType.getRank();
1553 for (
int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1556 int dim = vectorType.getRank() - i - 1;
1557 if (srcStrides[dim + rankDiff] != 1 ||
1558 srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
// Pattern: drops foldable innermost unit dims from a minor-identity
// vector.transfer_read by reading from a rank-reduced memref.subview,
// shape-casting the mask, and shape-casting the result back. Bails when a
// dropped dim is out-of-bounds. NOTE(review): the struct header and the
// final shape-cast replacement are elided in this excerpt.
1570 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
// TODO: support 0-d corner case.
1573 if (readOp.getTransferRank() == 0)
1576 auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
1580 if (!readOp.getPermutationMap().isMinorIdentity())
1583 auto targetType = readOp.getVectorType();
1584 if (targetType.getRank() <= 1)
1587 FailureOr<size_t> maybeDimsToDrop =
1589 if (failed(maybeDimsToDrop))
1592 size_t dimsToDrop = maybeDimsToDrop.value();
1593 if (dimsToDrop == 0)
// Dims marked out-of-bounds cannot be silently dropped.
1596 auto inBounds = readOp.getInBoundsValues();
1597 auto droppedInBounds =
ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1598 if (llvm::is_contained(droppedInBounds,
false))
1601 auto resultTargetVecType =
1602 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1603 targetType.getElementType(),
1604 targetType.getScalableDims().drop_back(dimsToDrop));
1606 auto loc = readOp.getLoc();
// Build a rank-reduced subview covering the same memory.
1613 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1614 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1617 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1618 Value rankedReducedView =
1619 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1620 readOp.getBase(), offsets, sizes, strides);
1622 cast<ShapedType>(rankedReducedView.
getType()), resultTargetVecType);
// Reduce the mask's rank in lock-step with the vector type.
1625 Value mask = readOp.getMask();
1627 auto maskType = cast<VectorType>(mask.
getType());
1628 auto reducedMaskType = VectorType::get(
1629 maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
1630 maskType.getScalableDims().drop_back(dimsToDrop));
1631 mask = rewriter.
createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
1636 rewriter, loc, resultTargetVecType, rankedReducedView,
1637 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1638 readOp.getPadding(), mask, inBoundsAttr);
// Pattern: mirror of the transfer_read case above for
// vector.transfer_write — drops foldable innermost unit dims via a
// rank-reduced memref.subview, shape-casting the written vector and the
// mask. NOTE(review): the struct header and several elided lines as above.
1667 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
// TODO: support 0-d corner case.
1670 if (writeOp.getTransferRank() == 0)
1673 auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
1677 if (!writeOp.getPermutationMap().isMinorIdentity())
1680 auto targetType = writeOp.getVectorType();
1681 if (targetType.getRank() <= 1)
1684 FailureOr<size_t> maybeDimsToDrop =
1686 if (failed(maybeDimsToDrop))
1689 size_t dimsToDrop = maybeDimsToDrop.value();
1690 if (dimsToDrop == 0)
// Dims marked out-of-bounds cannot be silently dropped.
1693 auto inBounds = writeOp.getInBoundsValues();
1694 auto droppedInBounds =
ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1695 if (llvm::is_contained(droppedInBounds,
false))
1698 auto resultTargetVecType =
1699 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1700 targetType.getElementType(),
1701 targetType.getScalableDims().drop_back(dimsToDrop));
// Build a rank-reduced subview covering the same memory.
1710 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1711 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1714 writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1716 Value rankedReducedView =
1717 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1718 writeOp.getBase(), offsets, sizes, strides);
1720 cast<ShapedType>(rankedReducedView.
getType()), resultTargetVecType);
// Shape-cast the value being written down to the reduced rank.
1722 auto shapeCast = rewriter.
createOrFold<vector::ShapeCastOp>(
1723 loc, resultTargetVecType, writeOp.getVector());
1726 Value mask = writeOp.getMask();
1728 auto maskType = cast<VectorType>(mask.
getType());
1729 auto reducedMaskType = VectorType::get(
1730 maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
1731 maskType.getScalableDims().drop_back(dimsToDrop));
1732 mask = rewriter.
createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
1737 writeOp, shapeCast, rankedReducedView,
1738 writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1739 mask, inBoundsAttr);
// Pattern: canonicalizes a 3-iterator vector.contraction to the
// matmul-transposed (MMT) form {(m,k),(n,k),(m,n)} by inserting
// vector.transpose ops on lhs/rhs as needed. Transposes are pushed through
// arith.extsi/extui so the transpose happens on the narrow type. A
// user-supplied filter can veto the rewrite. NOTE(review): the struct
// header, map/dim setup, and the final replacement are elided here.
1752 std::function<LogicalResult(vector::ContractionOp op)>;
1757 filter(std::move(constraint)) {}
// Caller-provided constraint on which contractions to canonicalize.
1761 if (failed(filter(op)))
1767 Value res = op.getAcc();
1771 auto infer = [&](MapList m) {
1778 static constexpr std::array<int64_t, 2> perm = {1, 0};
1779 auto iteratorTypes = op.getIteratorTypes().getValue();
// Only plain (m, n, k) 3-iterator contractions are considered.
1781 if (iteratorTypes.size() != 3 ||
1788 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1789 if (maps == canonicalForm)
// Transposes a matrix; if it comes from a sign/zero extension, transpose
// the narrow input and re-extend, keeping the extension outermost.
1794 auto createTranspose = [&rewriter, loc](
Value mat) ->
Value {
1795 if (
auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1797 vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
1798 VectorType newType =
1799 cast<VectorType>(trans.
getType())
1800 .clone(cast<VectorType>(mat.getType()).getElementType());
1801 return arith::ExtSIOp::create(rewriter, loc, newType, trans);
1803 if (
auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1805 vector::TransposeOp::create(rewriter, loc,
zext.getIn(), perm);
1806 VectorType newType =
1807 VectorType::get(cast<VectorType>(trans.
getType()).getShape(),
1808 cast<VectorType>(mat.getType()).getElementType());
1809 return arith::ExtUIOp::create(rewriter, loc, newType, trans);
1811 return vector::TransposeOp::create(rewriter, loc, mat, perm);
// Case analysis over all non-canonical indexing-map layouts.
1814 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1815 rhs = createTranspose(
rhs);
1816 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1817 lhs = createTranspose(
lhs);
1818 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1819 rhs = createTranspose(
rhs);
1820 lhs = createTranspose(
lhs);
1821 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1823 rhs = createTranspose(
rhs);
1824 lhs = createTranspose(
lhs);
1825 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1827 rhs = createTranspose(
rhs);
1828 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1830 lhs = createTranspose(
lhs);
1831 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1838 op.getIteratorTypes());
// Pattern template (over an arith extension op type): folds matching
// extension ops on both contraction operands into the contraction itself,
// replacing it with one built on the pre-extension values.
// NOTE(review): the struct header is elided in this excerpt.
1863template <
typename ExtOp>
1871 auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1872 auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
// Both operands must come from the same extension kind to fold.
1874 if (!lhsDefOp || !rhsDefOp) {
1876 "no defining op on contract operands");
1880 contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1881 contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1882 contractOp.getIteratorTypesAttr());
// Pattern fragment: folds chained additive vector.reduction ops —
// reduction<add>(v, reduction<add>(u, acc)) becomes
// reduction<add>(add(u, v), acc), using addi for integers and addf for
// floats. NOTE(review): the struct header and the final replacement are
// elided in this excerpt.
1904 if (op.getKind() != vector::CombiningKind::ADD)
// Only scalar int/float accumulators are handled.
1912 if (!
acc.getType().isIntOrFloat())
1915 auto parentReduction =
acc.getDefiningOp<vector::ReductionOp>();
1916 if (!parentReduction)
1921 if (isa<IntegerType>(
acc.getType())) {
1923 loc, parentReduction.getVector(), op.getVector());
1925 vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
1929 parentReduction.getAcc());
// Helper fragment: returns a VectorType equal to the input with all
// non-scalable unit dims removed; degenerates to vector<1 x elt> if every
// dim was dropped. NOTE(review): the function signature is elided.
1940 auto inVecShape = inVecTy.getShape();
1943 for (
auto [dim, isScalable] :
1944 llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
// Keep scalable unit dims — only fixed size-1 dims are droppable.
1945 if (dim == 1 && !isScalable)
1948 newShape.push_back(dim);
1949 newScalableDims.push_back(isScalable);
// Never produce a rank-0 vector; fall back to a single unit dim.
1952 if (newShape.empty()) {
1953 newShape.push_back(1);
1954 newScalableDims.push_back(
false);
1957 return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1994 if (!resultVectorType)
2001 if (!sourceVectorType)
2003 if (sourceVectorType.getRank() < 2)
2009 auto opVectorType = cast<VectorType>(operand.getType());
2011 if (newVType == opVectorType)
2014 auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
2015 newOperands.push_back(opSC);
2018 VectorType newResultVectorType =
2023 newResultVectorType, op->
getAttrs());
2058 VectorType sourceType = op.getSourceVectorType();
2059 VectorType sourceTypeWithoutUnitDims =
2062 if (sourceType == sourceTypeWithoutUnitDims)
2069 for (
auto [i, dim] : llvm::enumerate(sourceDims)) {
2070 droppedDimsBefore[i] = droppedDims;
2071 if (dim == std::make_tuple(1,
false))
2079 if (sourceDims[idx] == std::make_tuple(1,
false))
2081 newPerm.push_back(idx - droppedDimsBefore[idx]);
2087 if (newPerm.empty()) {
2088 newPerm.push_back(0);
2093 auto dropDimsShapeCast = vector::ShapeCastOp::create(
2094 rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
2096 auto transposeWithoutUnitDims =
2097 vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
2100 op, op.getResultVectorType(), transposeWithoutUnitDims);
2137 for (
OpOperand &operand : forOp.getInitArgsMutable()) {
2138 auto vectorType = dyn_cast<VectorType>(operand.get().getType());
2143 if (vectorType == newVectorType)
2148 return vector::ShapeCastOp::create(
b, loc, type, source);
2152 castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
2154 replaceAndCastForOpIterArg(rewriter, forOp, operand,
2181 if (op.getKind() != vector::CombiningKind::ADD)
2184 Type elemType = op.getSourceVectorType().getElementType();
2187 if (!isa<FloatType>(elemType))
2190 auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
2193 auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
2200 auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
2201 addLhs.getLhs(), vAdd.getRhs());
2220 unsigned maxNumElementsToExtract,
2223 maxNumElementsToExtract(maxNumElementsToExtract) {}
2227 VectorType type = op.getSourceVectorType();
2228 if (type.isScalable() || op.isMasked())
2230 assert(type.getRank() == 1 &&
"Expected a 1-d vector");
2232 int64_t numElems = type.getNumElements();
2233 if (numElems > maxNumElementsToExtract) {
2235 op, llvm::formatv(
"has too many vector elements ({0}) to break down "
2236 "(max allowed: {1})",
2237 numElems, maxNumElementsToExtract));
2242 for (
auto [idx, extractedElem] : llvm::enumerate(extracted))
2243 extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
2246 Value res = extracted.front();
2247 for (
auto extractedElem : llvm::drop_begin(extracted))
2249 extractedElem, op.getFastmathAttr());
2252 op.getFastmathAttr());
2259 unsigned maxNumElementsToExtract = 0;
2278template <
typename MulOpType>
2283 bool isValidBroadcastSource(vector::BroadcastOp broadcastOp)
const {
2286 if (!broadcastOp.computeBroadcastedUnitDims().empty())
2289 auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
2290 return srcType && srcType.getRank() != 2;
2295 auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
2298 if (resType.getRank() != 2)
2303 auto matchOuterProduct =
2305 Value operandB) -> FailureOr<vector::OuterProductOp> {
2306 auto transposedLhs = operandA.
getDefiningOp<vector::TransposeOp>();
2311 if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2314 auto broadcastedLhs =
2315 transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2316 if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2319 auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2320 if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2323 return vector::OuterProductOp::create(
2324 rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2325 broadcastedRhs.getSource(),
Value(), vector::CombiningKind::ADD);
2328 Value lhs = mulOp->getOperand(0),
rhs = mulOp->getOperand(1);
2329 auto maybeOuterP = matchOuterProduct(
lhs,
rhs);
2331 if (failed(maybeOuterP))
2332 maybeOuterP = matchOuterProduct(
rhs,
lhs);
2333 if (failed(maybeOuterP))
2335 rewriter.
replaceOp(mulOp, maybeOuterP->getResult());
2349void mlir::vector::populateVectorMaskMaterializationPatterns(
2352 patterns.
add<VectorCreateMaskOpConversion,
2353 MaterializeTransferMask<vector::TransferReadOp>,
2354 MaterializeTransferMask<vector::TransferWriteOp>>(
2355 patterns.
getContext(), force32BitVectorIndices, benefit);
2359void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2365void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2367 patterns.
add<BubbleDownVectorBitCastForExtract,
2368 BubbleDownBitCastForStridedSliceExtract,
2369 BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2373void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2375 std::function<
bool(vector::BitCastOp)> controlFn,
PatternBenefit benefit) {
2377 std::move(controlFn), benefit);
2382 std::function<LogicalResult(vector::ContractionOp)> constraint,
2385 std::move(constraint));
2390 patterns.
add<MultiReduceToContract, CombineContractBroadcastMask,
2391 CombineContractABTranspose, CombineContractResultTranspose>(
2404 patterns.
add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2405 ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2412 patterns.
add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.
getContext(),
2416void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2423void mlir::vector::populateBreakDownVectorReductionPatterns(
2427 maxNumElementsToExtract, benefit);
2432 patterns.
add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2433 FoldArithToVectorOuterProduct<arith::MulIOp>>(
2441#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
static uint64_t zext(uint32_t arg)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Drop inner most contiguous unit dimensions from transfer_read operand.
Drop inner most contiguous unit dimensions from transfer_write operand. E.g., vector....
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
IntegerAttr getIndexAttr(int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold elementwise op on vectors to the vector dialect.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the most inner unit dims in xfer Ops.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction with M...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
BreakDownVectorReduction(MLIRContext *context, unsigned maxNumElementsToExtract, PatternBenefit benefit)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction suitab...
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit, FilterConstraintType constraint)
Pattern to fold chained reduction to a series of vector additions and a final reduction....
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
For vectors with at least one unit dim, replaces: elementwise(a, b) with: sc_a = shape_cast(a) sc_b =...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
A pattern to drop unit dims from the iter_args of an scf.for.
LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override
A pattern to drop unit dims from vector.transpose.
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
Pattern to fold arithmetic extensions on floating point data types into vector contraction operations...
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override
Pattern to eliminate redundant zero-constants added to reduction operands. It's enough for there to b...
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.