#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "vector-to-vector"
// Helper that extracts a vector of IntType out of an ArrayAttr of IntegerAttr.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
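// Illustrative example (editor's hedged reconstruction; the SSA names and
// types are illustrative, not from the source): the pattern below converts
// arith.mulf/muli + vector.multi_reduction<add> into a single contraction,
//
//   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
//   %1 = vector.multi_reduction <add>, %0, %acc [1]
//       : vector<8x32x16xf32> to vector<8x16xf32>
//
// which becomes
//
//   %1 = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
//                           affine_map<(d0, d1, d2) -> (d0, d2)>],
//          iterator_types = ["parallel", "reduction", "parallel"],
//          kind = #vector.kind<add>} %arg0, %arg1, %acc
//       : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>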
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    // Surviving dims stay parallel; reduced dims become reduction iterators.
    // ...
    if (!isReduceDim.value()) {
      iteratorTypes.push_back(vector::IteratorType::parallel);
      // ...
    } else {
      iteratorTypes.push_back(vector::IteratorType::reduction);
    }
    // ...
    auto dstMap = AffineMap::get(/*dimCount=*/..., /*symbolCount=*/
                                 0, exprs, reduceOp.getContext());
    // ... (inside the lambda wrapping each iterator type as an attribute):
    return IteratorTypeAttr::get(rewriter.getContext(), t);
    // ...
  }
};
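// Illustrative example (hedged reconstruction): a vector.transpose feeding
// the LHS or RHS of a vector.contract is folded by composing the transpose
// permutation into that operand's indexing map:
//
//   %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
//   %1 = vector.contract {...} %0, %arg1, %acc ...
//
// becomes a vector.contract that consumes %arg0 directly, with the LHS
// indexing map composed with affine_map<(d0, d1) -> (d1, d0)>.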
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    // Fold a transpose feeding either side into the indexing maps.
    for (Value *operand : {&lhs, &rhs}) {
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      // ... (compose the permutation into this operand's map)
      *operand = transposeOp.getVector();
    }
    // ...
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        /*...*/);
    return success();
  }
};
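// Illustrative example (hedged reconstruction): a vector.transpose of the
// contraction result, with a matching transpose on the accumulator, is
// absorbed by composing the result permutation into the accumulator/result
// indexing map:
//
//   %accT = vector.transpose %acc, [1, 0]
//   %c    = vector.contract {...} %lhs, %rhs, %accT
//   %resT = vector.transpose %c, [1, 0]
//
// becomes a single vector.contract on %acc whose result map includes the
// permutation, so the value of %resT is produced directly.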
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();
    // The accumulator must be transposed in a compatible way.
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    // ... (build permutation maps for the result and accumulator transposes)
    auto combinedResMap = resTMap.compose(contractMap);
    // ...
    maps.back() = combinedResMap;
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        /*...*/);
    return success();
  }
};
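// Illustrative example (hedged reconstruction): vector.broadcast ops feeding
// a (possibly masked) vector.contract are folded into the contraction by
// composing the broadcast's dimension mapping into the operand's indexing
// map and compressing the now-unused leading dimensions:
//
//   %lhs = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
//   %r   = vector.contract {...} %lhs, %arg1, %acc ...
//
// becomes a vector.contract that uses %arg0 directly, provided no inner
// dimension is broadcast and no non-unit reduction dimension is dropped; any
// mask is shape_cast to the reduced shape.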
static FailureOr<Value>
combineContractAndBroadcast(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) {
  SmallVector<AffineMap> maps =
      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
  Value lhs = contractOp.getLhs();
  Value rhs = contractOp.getRhs();
  size_t index = 0;
  for (Value *operand : {&lhs, &rhs}) {
    AffineMap &map = maps[index++];
    auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      continue;
    // The contraction op can only take vectors as operands.
    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
    if (!srcType ||
        srcType.getRank() == broadcast.getResultVectorType().getRank())
      continue;
    int64_t rankDiff =
        broadcast.getResultVectorType().getRank() - srcType.getRank();
    bool innerDimBroadcast = false;
    for (const auto &dim : llvm::enumerate(srcType.getShape())) {
      if (dim.value() !=
          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
        innerDimBroadcast = true;
        break;
      }
      // ...
    }
    // The contract op doesn't support inner-dimension broadcasts.
    if (innerDimBroadcast)
      continue;

    // It would be incorrect to fold a broadcast onto a reduction dimension
    // of non-unit size.
    bool nonUnitDimReductionBroadcast = false;
    for (int64_t i = 0; i < rankDiff; ++i) {
      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
          isReductionIterator(contractOp.getIteratorTypes()
                                  .getValue()[map.getDimPosition(i)])) {
        nonUnitDimReductionBroadcast = true;
        break;
      }
    }
    if (nonUnitDimReductionBroadcast)
      continue;
    // ...
    map = broadcastMap.compose(map);
    *operand = broadcast.getSource();
  }
  // ...

  // Compute the combined iterators, dropping the now-unused dimensions.
  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
  SmallVector<Attribute> iterators;
  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
    if (!unusedDimsBitVector.test(i))
      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
  }

  // Check whether any of the unused dims is non-unit in the old mask.
  VectorType oldMaskType;
  bool isAnyUnusedDimNonUnit = false;
  if (maskingOp) {
    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
        isAnyUnusedDimNonUnit = true;
        break;
      }
    }
  }

  // Bail if compressing the unused dims would remove every reduction-dim
  // pair; a contraction without one would not verify.
  bool hasReductionIteratorApplyingOnBothSides = false;
  for (unsigned i = 0; i < iterators.size(); ++i) {
    // ... (a reduction iterator must appear in both LHS and RHS maps)
    hasReductionIteratorApplyingOnBothSides = true;
    // ...
  }
  if (!hasReductionIteratorApplyingOnBothSides)
    return failure();
  // ...

  Operation *newOp = vector::ContractionOp::create(
      rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
      rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

  // Handle the mask.
  if (maskingOp) {
    if (isAnyUnusedDimNonUnit)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Cannot drop non-unit mask dim.");
    assert(unusedDimsBitVector.size() ==
               static_cast<size_t>(oldMaskType.getRank()) &&
           "The mask rank is incorrect!");

    // The unused (unit) dims are dropped from the front of the mask shape.
    Value mask = maskingOp.getMask();
    if (unusedDimsBitVector.count() != 0) {
      auto newShape =
          oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
      auto newShapeScalableDims =
          oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
      VectorType maskOpType =
          VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
      mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
                                         maskOpType, maskingOp.getMask())
                 .getResult();
    }
    // ...
  }
  // ...
}
struct CombineContractBroadcastMask
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
  }
};
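// Illustrative example (hedged reconstruction): cast ops are pushed past a
// broadcast so the cast runs on the narrower pre-broadcast value:
//
//   %b = vector.broadcast %s : f32 to vector<8xf32>
//   %e = arith.extf %b : vector<8xf32> to vector<8xf64>
//
// becomes
//
//   %e = arith.extf %s : f32 to f64
//   %b = vector.broadcast %e : f64 to vector<8xf64>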
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
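// Illustrative example (hedged reconstruction): an elementwise op whose
// operands are all transposed by the same permutation is pulled above the
// transposes, leaving a single transpose of the result:
//
//   %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
//   %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
//   %r  = arith.addf %at, %bt : vector<2x4xf32>
//
// becomes
//
//   %0 = arith.addf %a, %b : vector<4x2xf32>
//   %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>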
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Collect the transpose permutations of all operands; constants are
    // allowed and get an inverse transpose inserted below.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // All transposes must share one permutation for the reorder to be valid.
    if (!llvm::all_equal(transposeMaps))
      return failure();

    // Invert the permutation for constant operands.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant: insert the inverse transpose.
        auto vectorType =
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(vector::TransposeOp::create(
            rewriter, operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    // ... (rebuild the elementwise op on srcValues, then emit one transpose of
    //      the result using transposeMaps.front())
    return success();
  }
};
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
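// Illustrative example (hedged reconstruction): a bitcast is bubbled below a
// vector.extract so the extract happens on the wider source elements first:
//
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : f16 from vector<8xf16>
//
// becomes (expand ratio 8/4 = 2, so element 3 lives in packed element 1)
//
//   %0 = vector.extract %src[1] : f32 from vector<4xf32>
//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : f16 from vector<2xf16>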
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source; this
    // avoids an infinite loop since the pattern can generate such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now,
    // e.g. vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // Only static positions are supported.
    auto mixedPos = extractOp.getMixedPosition();
    if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Extract the single packing scalar, wrap it in a 1-element vector, ...
    Value packedValue = vector::ExtractOp::create(
        rewriter, loc, castOp.getSource(), index / expandRatio);
    Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
                                           rewriter.getZeroAttr(packedVecType));
    packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
                                           /*position=*/0);

    // ... bitcast it to the destination element type, ...
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        vector::BitCastOp::create(rewriter, loc, packedType, packedValue);

    // ... and finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
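// Illustrative example (hedged reconstruction): the same bubbling applied to
// vector.extract_strided_slice, scaling trailing offsets and sizes by the
// expand ratio:
//
//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {offsets = [4], sizes = [4],
//        strides = [1]} : vector<8xf16> to vector<4xf16>
//
// becomes
//
//   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2],
//        strides = [1]} : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>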
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If fewer offsets than the rank are given, the last bitcasted dimension
    // is selected in full; otherwise scale down its offset.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for the sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
        newOffsets, newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
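// Illustrative example (hedged reconstruction): a bitcast is bubbled above a
// vector.insert by bitcasting both the inserted value and the destination:
//
//   %0 = vector.insert %val, %dst[4] : vector<16xf16> into vector<8x16xf16>
//   %1 = vector.bitcast %0 : vector<8x16xf16> to vector<8x8xf32>
//
// becomes
//
//   %0 = vector.bitcast %val : vector<16xf16> to vector<8xf32>
//   %1 = vector.bitcast %dst : vector<8x16xf16> to vector<8x8xf32>
//   %2 = vector.insert %0, %1 [4] : vector<8xf32> into vector<8x8xf32>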
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
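// Illustrative example (hedged reconstruction): likewise for
// vector.insert_strided_slice, with the insertion offsets scaled by the
// shrink ratio:
//
//   %0 = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//        : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
//
// becomes (shrink ratio 8/4 = 2)
//
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %0, %1 {offsets = [1], strides = [1]}
//        : vector<2xf32> into vector<4xf32>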
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the same rank for the insert's source and destination vectors.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // The insert's source shape must be castable to the destination type.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
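// Illustrative example (hedged reconstruction): a 1-D element-count-shrinking
// bitcast is broken into per-chunk extract_strided_slice / bitcast /
// insert_strided_slice ops:
//
//   %1 = vector.bitcast %0 : vector<32xf16> to vector<16xf32>
//
// becomes two chunks (shrink ratio 32/16 = 2), accumulated into an initial
// zero vector:
//
//   %e0 = vector.extract_strided_slice %0 {offsets = [0], sizes = [16],
//         strides = [1]} : vector<32xf16> to vector<16xf16>
//   %b0 = vector.bitcast %e0 : vector<16xf16> to vector<8xf32>
//   %r0 = vector.insert_strided_slice %b0, %init {offsets = [0], strides = [1]}
//   ... and likewise with offsets [16] / [8] for the second chunk.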
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");
    // Only support rank 1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    // Start from a zero-filled result and insert each bitcast chunk.
    Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                           rewriter.getZeroAttr(elemType));
    Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
    // ...
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = ExtractStridedSliceOp::create(
          rewriter, loc, bitcastOp.getSource(),
          /*offsets=*/ArrayRef<int64_t>{i * castDstLastDim}, /*...*/);
      Value bitcast =
          BitCastOp::create(rewriter, loc, newCastDstType, extracted);
      res = InsertStridedSliceOp::create(
          rewriter, loc, bitcast, res,
          /*offsets=*/ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio},
          /*...*/);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};
static bool haveSameShapeAndScaling(Type t, Type u) {
  auto tVec = dyn_cast<VectorType>(t);
  auto uVec = dyn_cast<VectorType>(u);
  if (!tVec)
    return !uVec;
  if (!uVec)
    return false;
  return tVec.getShape() == uVec.getShape() &&
         tVec.getScalableDims() == uVec.getScalableDims();
}

/// If `type` is shaped, clone it with `newElementType`. Otherwise,
/// return `newElementType`.
static Type cloneOrReplace(Type type, Type newElementType) {
  if (auto shapedType = dyn_cast<ShapedType>(type)) {
    return shapedType.clone(newElementType);
  }
  return newElementType;
}

/// If `value` is the result of a broadcast-like operation, return the source
/// of the broadcast; otherwise, return null.
static Value getBroadcastLikeSource(Value value) {
  Operation *op = value.getDefiningOp();
  if (!op)
    return {};
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcast.getSource();
  // ...
  return {};
}
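// Illustrative example (hedged reconstruction): an elementwise op whose
// operands are all broadcasts from identically-shaped values (or splat
// constants) is rebuilt on the pre-broadcast values:
//
//   %a = vector.broadcast %arg1 : index to vector<1x4xindex>
//   %b = vector.broadcast %arg2 : index to vector<1x4xindex>
//   %r = arith.addi %a, %b : vector<1x4xindex>
//
// becomes
//
//   %0 = arith.addi %arg1, %arg2 : index
//   %r = vector.broadcast %0 : index to vector<1x4xindex>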
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    // ...
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    Type resultElemType = resultType.getElementType();
    // Get the broadcast source of the first non-constant operand.
    Value broadcastSource;
    for (Value operand : op->getOperands()) {
      Operation *definingOp = operand.getDefiningOp();
      if (!definingOp)
        return failure();
      if (definingOp->hasTrait<OpTrait::ConstantLike>())
        continue;
      broadcastSource = getBroadcastLikeSource(operand);
      break;
    }
    if (!broadcastSource)
      return failure();
    Type unbroadcastResultType =
        cloneOrReplace(broadcastSource.getType(), resultElemType);

    // All operands must be broadcasts from identically-shaped values or splat
    // constants; otherwise the reorder wouldn't be safe.
    if (!llvm::all_of(op->getOperands(), [&](Value val) {
          if (auto source = getBroadcastLikeSource(val))
            return haveSameShapeAndScaling(source.getType(),
                                           broadcastSource.getType());
          SplatElementsAttr splatConst;
          return matchPattern(val, m_Constant(&splatConst));
        })) {
      return rewriter.notifyMatchFailure(
          op,
          "not all operands are constants or broadcasts from the same type");
    }

    // Collect the pre-broadcast source values, resizing splat constants.
    SmallVector<Value> srcValues;
    for (Value operand : op->getOperands()) {
      SplatElementsAttr splatConst;
      if (matchPattern(operand, m_Constant(&splatConst))) {
        Type newType = cloneOrReplace(unbroadcastResultType,
                                      getElementTypeOrSelf(operand.getType()));
        Attribute newConst = splatConst.getSplatValue<Attribute>();
        if (auto newTypeShaped = dyn_cast<ShapedType>(newType))
          newConst = splatConst.resizeSplat(newTypeShaped);
        Operation *newConstOp =
            operand.getDefiningOp()->getDialect()->materializeConstant(
                rewriter, newConst, newType, operand.getLoc());
        srcValues.push_back(newConstOp->getResult(0));
      } else {
        srcValues.push_back(operand.getDefiningOp()->getOperand(0));
      }
    }

    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        unbroadcastResultType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, resultType, elementwiseOp->getResults());
    return success();
  }
};
class ExtractOpFromElementwise final
    : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *eltwise = op.getSource().getDefiningOp();
    // vector.fma has its own lowering path, so skip it here.
    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
        isa<vector::FMAOp>(eltwise))
      return failure();
    // ...
    if (!op.getDynamicPosition().empty())
      return rewriter.notifyMatchFailure(
          op, "dynamic position not yet implemented");

    Type dstType = op.getType();
    // ... (extract from each operand, then clone the elementwise op on the
    //      extracted scalars via an IRMapping):
    Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
    mapping.map(arg, newArg);
    // ...
  }
};
/// Element types supported when sinking an extract into a memory op.
static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;
  return type.isIntOrFloat();
}
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return failure();
    // Check for a single use so the load can be removed.
    if (!loadOp->hasOneUse())
      return failure();

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();
    // Non-byte-aligned types may require special handling; skip them for now.
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return failure();

    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    // ...
    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();
    // ...
    // Shift the load indices by the extract position along the leading dims.
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      // ...
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    Value base = loadOp.getBase();
    if (extractVecType) {
      // ... (the result is still a vector: emit a narrower vector.load)
    }
    // ...
  }
};
class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Value toStore = op.getValueToStore();
    Value source = getBroadcastLikeSource(toStore);
    if (!source)
      return rewriter.notifyMatchFailure(
          op, "value to store is not from a broadcast");
    // ...
    Value base = op.getBase();
    // ...
    if (isa<VectorType>(source.getType())) {
      // ... (vector source: store it directly with the adjusted indices)
    }
    // ...
  }
};
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // Create a constant vector of the index range [0, dim), in 32 bits when
  // requested for higher SIMD parallelism, otherwise in 64 bits.
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    // ... (0-D i32 case)
  } else if (dim == 0) {
    // ... (0-D i64 case)
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    // ... (cast the offset to the index type, then)
    Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
    indices = arith::AddIOp::create(rewriter, loc, ov, indices);
  }
  // Construct the vector comparison: indices < broadcast(bound).
  // ...
  Value bounds =
      vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
                               indices, bounds);
}
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask: elements in [0 .. dim - offset) are set,
    // elements in [dim - offset .. vector_length) are unset.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
    Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
    Value mask = vector::CreateMaskOp::create(
        rewriter, loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds mask with the mask specified on the op.
      mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      // ...
    });
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  if (!denseAttr)
    return false;
  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
  // (Pattern on arith.select producing an i1 vector; the struct header was
  //  lost in this excerpt.)
  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions are handled.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();
    // ...
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace the select(cond, all-true, all-false) with the condition itself.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    // ...
  }
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // The vector may be a slice of the memref, so offset the checked index by
  // the rank difference.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // An inner dim can be folded only if both the memref dim and the vector
    // dim are unit-sized and the stride is 1.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}
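// Illustrative example (hedged reconstruction): trailing unit dims shared by
// the memref and the vector are dropped via a rank-reducing subview:
//
//   %0 = vector.transfer_read %src[%c0, %c0, %c0], %pad
//        : memref<8x8x1xf32>, vector<4x1xf32>
//
// becomes
//
//   %sv = memref.subview %src[0, 0, 0] [8, 8, 1] [1, 1, 1]
//        : memref<8x8x1xf32> to memref<8x8xf32>
//   %r  = vector.transfer_read %sv[%c0, %c0], %pad
//        : memref<8x8xf32>, vector<4xf32>
//   %0  = vector.shape_cast %r : vector<4xf32> to vector<4x1xf32>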
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    // ... (build zero offsets, mixed sizes, and unit strides for the subview)
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView =
        memref::SubViewOp::create(rewriter, loc, resultMemrefType,
                                  readOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = vector::TransferReadOp::create(
        rewriter, loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    // ... (build zero offsets, mixed sizes, and unit strides for the subview)
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView =
        memref::SubViewOp::create(rewriter, loc, resultMemrefType,
                                  writeOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};
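// Hedged note on the next pattern (reconstructed): a vector.contract with
// row-major matmul maps
//   (m, k), (k, n), (m, n)
// is canonicalized to the transposed-RHS ("MMT") form
//   (m, k), (n, k), (m, n)
// by inserting vector.transpose ops on the operands as needed; transposes are
// hoisted above arith.extsi/extui when possible so they apply to the narrower
// pre-extension type.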
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT": A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector transpose, folding it through sign-/zero-extensions so
    // the transpose applies to the narrower type when possible.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return arith::ExtSIOp::create(rewriter, loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return arith::ExtUIOp::create(rewriter, loc, newType, trans);
      }
      return vector::TransposeOp::create(rewriter, loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
  // (Chained vector.reduction folding; the struct header was lost in this
  //  excerpt. See populateChainedVectorReductionFoldingPatterns below.)
  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // The accumulator must itself be a reduction of the same kind.
    Value acc = op.getAcc();
    // ...
    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();
    // ...
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
                                   op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
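// Illustrative example (hedged reconstruction): two chained
// vector.reduction <add> ops are fused by adding the vectors first:
//
//   %a = vector.reduction <add>, %x, %acc : vector<8xf32> into f32
//   %b = vector.reduction <add>, %y, %a  : vector<8xf32> into f32
//
// becomes
//
//   %v = arith.addf %x, %y : vector<8xf32>
//   %b = vector.reduction <add>, %v, %acc : vector<8xf32> into f32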
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped: return a 1-element vector.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();
    // ...
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");

      auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Rebuild the elementwise op on the shape-cast operands, then cast back.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                             elementwiseOp->getResult(0));
    return success();
  }
};
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop unit dims from the transpose permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // If all dims were unit, the reduced source is vector<1xT> and the
    // permutation must be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    // Drop the unit dims via shape_cast, transpose, then cast back.
    auto dropDimsShapeCast = vector::ShapeCastOp::create(
        rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
    auto transposeWithoutUnitDims =
        vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), transposeWithoutUnitDims);
    return success();
  }
};
  // (Body of DropUnitDimsFromScfForOp, named in
  //  populateDropUnitDimWithShapeCastPatterns below.)
  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Find the first iter_arg with droppable unit dims; later arguments are
    // handled by further applications of this pattern.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;

      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;

      // Replace the iter operand, casting via shape_cast on entry and exit.
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return vector::ShapeCastOp::create(b, loc, type, source);
      };
      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp, replaceAndCastForOpIterArg(
                                    rewriter, forOp, operand,
                                    replacement, castFn));
      return success();
    }
    return failure();
  }
  // (Second reduction-folding pattern; the struct header was lost in this
  //  excerpt. It reassociates an fadd chain feeding a float
  //  vector.reduction <add>.)
  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // The integer case should be handled by `arith.addi` folders; only check
    // for floats here.
    Type elemType = op.getSourceVectorType().getElementType();
    if (!isa<FloatType>(elemType))
      return failure();
    // ...
    auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
                                        addLhs.getLhs(), vAdd.getRhs());
    // ...
  }
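// Illustrative example (hedged reconstruction): a short 1-D vector.reduction
// is broken into scalar extracts combined with arith ops:
//
//   %r = vector.reduction <add>, %v : vector<2xf32> into f32
//
// becomes (when 2 <= maxNumElementsToExtract)
//
//   %0 = vector.extract %v[0] : f32 from vector<2xf32>
//   %1 = vector.extract %v[1] : f32 from vector<2xf32>
//   %r = arith.addf %0, %1 : f32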
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
                                                static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};
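// Illustrative example (hedged reconstruction): an arith.mulf/muli on
// 1-D-to-2-D broadcasts in outer-product form
//
//   %lhsBc = vector.broadcast %lhs : vector<4xf32> to vector<8x4xf32>
//   %lhsT  = vector.transpose %lhsBc, [1, 0]
//            : vector<8x4xf32> to vector<4x8xf32>
//   %rhsBc = vector.broadcast %rhs : vector<8xf32> to vector<4x8xf32>
//   %mul   = arith.mulf %lhsT, %rhsBc : vector<4x8xf32>
//
// becomes
//
//   %mul = vector.outerproduct %lhs, %rhs : vector<4xf32>, vector<8xf32>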
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;

  // Returns whether a vector.broadcast matches the 1-D-to-2-D form expected
  // by vector.outerproduct.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail if it is not a 1-to-2 dimension broadcast, to avoid generating
    // shape_casts/broadcasts which do not belong in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Avoid broadcasts like f32 or vector<f32> -> ResType.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
    if (!resType)
      return failure();
    if (resType.getRank() != 2)
      return failure();

    // If operandA is tr(broadcast(A)) and operandB is broadcast(B), with both
    // broadcasts 1-D-to-2-D, then the multiply is an outer product A x B.
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return vector::OuterProductOp::create(
          rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity: the transposed op can be either operand.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
      patterns.getContext());
}

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
  // ...
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
                                                        benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // ... (adds the chained vector.reduction folding patterns defined above)
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
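// Editor's hedged usage sketch (not part of the original file): these
// populate* entry points are typically driven from a pass with the greedy
// pattern driver, e.g. (assuming a recent MLIR where the driver is named
// applyPatternsGreedily; older releases call it applyPatternsAndFoldGreedily):
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorReductionToContractPatterns(patterns);
//   vector::populateSinkVectorOpsPatterns(patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();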