#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "vector-to-vector"
/// Helper that converts an ArrayAttr of IntegerAttr into a vector of integers
/// of the requested type.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
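// Illustrative sketch (shapes and maps chosen for exposition, not taken from
// the upstream docs): the pattern below turns an elementwise multiply feeding
// an additive multi-reduction into a single vector.contract.
//
//   %m = arith.mulf %a, %b : vector<8x32x16xf32>
//   %r = vector.multi_reduction <add>, %m, %acc [1]
//       : vector<8x32x16xf32> to vector<8x16xf32>
//
// is rewritten to roughly:
//
//   %r = vector.contract {
//       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
//                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
//                        affine_map<(d0, d1, d2) -> (d0, d2)>],
//       iterator_types = ["parallel", "reduction", "parallel"],
//       kind = #vector.kind<add>} %a, %b, %acc
//       : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>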
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    // Mark each dimension parallel or reduction based on the reduction mask;
    // parallel dimensions also appear in the result indexing map.
    // ...
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        // ...
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    // ...
    auto dstMap =
        AffineMap::get(/*dimCount=*/reductionMask.size(),
                       /*symbolCount=*/0, exprs, reduceOp.getContext());
    // ...
        llvm::map_range(iteratorTypes, [&](vector::IteratorType t) -> Attribute {
          return IteratorTypeAttr::get(rewriter.getContext(), t);
        })
    // ...
  }
};
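// Illustrative sketch (assumed types): a transpose feeding either contraction
// operand can be absorbed into that operand's indexing map, e.g.
//
//   %t = vector.transpose %lhs, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
//   %r = vector.contract {...} %t, %rhs, %acc : ...
//
// becomes a vector.contract that consumes %lhs directly, with the LHS
// indexing map composed with the permutation map of the transpose.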
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    // Fold a transpose on either side into that operand's indexing map.
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
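// Illustrative sketch (assumed types): when a contract's only user is a
// transpose, and its accumulator is itself produced by a transpose, both
// transposes can be folded into the contract's result indexing map:
//
//   %acc_t = vector.transpose %acc, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
//   %c = vector.contract {...} %lhs, %rhs, %acc_t : ... into vector<8x4xf32>
//   %r = vector.transpose %c, [1, 0] : vector<8x4xf32> to vector<4x8xf32>
//
// becomes a single vector.contract producing %r from %acc directly.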
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    // Compose the result transpose's permutation map into the contract's
    // accumulator/result indexing map.
    // ...
    auto combinedResMap = resTMap.compose(contractMap);
    // ...
    maps.back() = combinedResMap;
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
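// Illustrative sketch (assumed shapes): a broadcast feeding a contraction
// operand can be folded away by rewriting that operand's indexing map over
// the pre-broadcast shape, e.g.
//
//   %b = vector.broadcast %arg : vector<32x16xf32> to vector<8x32x16xf32>
//   %r = vector.contract {...} %b, %rhs, %acc : ...
//
// becomes a contract that reads %arg directly. Dimensions no longer used by
// any map are compressed out; when the contract is masked, the mask is
// shape_cast to the reduced rank. The helper below returns the new result
// value, or failure if the fold does not apply.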
static FailureOr<Value>
combineContractAndBroadcast(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) {
  SmallVector<AffineMap, 4> maps =
      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
  Value lhs = contractOp.getLhs();
  Value rhs = contractOp.getRhs();
  size_t index = 0;
  bool changed = false;
  for (Value *operand : {&lhs, &rhs}) {
    AffineMap &map = maps[index++];
    auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      continue;
    // vector.contract can only take vectors as operands.
    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
    if (!srcType ||
        srcType.getRank() == broadcast.getResultVectorType().getRank())
      continue;
    int64_t rankDiff =
        broadcast.getResultVectorType().getRank() - srcType.getRank();
    bool innerDimBroadcast = false;
    SmallVector<AffineExpr> originalDims;
    for (const auto &dim : llvm::enumerate(srcType.getShape())) {
      if (dim.value() !=
          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
        innerDimBroadcast = true;
        break;
      }
      originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
    }
    // vector.contract does not support inner-dimension broadcasts.
    if (innerDimBroadcast)
      continue;

    // It would be incorrect to fold a broadcast onto a reduction dimension
    // of non-unit size.
    bool nonUnitDimReductionBroadcast = false;
    for (int64_t i = 0; i < rankDiff; ++i) {
      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
          isReductionIterator(contractOp.getIteratorTypes()
                                  .getValue()[map.getDimPosition(i)])) {
        nonUnitDimReductionBroadcast = true;
        break;
      }
    }
    if (nonUnitDimReductionBroadcast)
      continue;

    AffineMap broadcastMap =
        AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                       originalDims, contractOp.getContext());
    map = broadcastMap.compose(map);
    *operand = broadcast.getSource();
    changed = true;
  }
  if (!changed)
    return failure();

  // Determine which dims are unused now that the maps have been composed with
  // the broadcast maps, compress them away, and keep the surviving iterators.
  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
  for (AffineMap &m : maps)
    m = compressDims(m, unusedDimsBitVector);
  SmallVector<Attribute, 4> iterators;
  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
    if (!unusedDimsBitVector.test(i))
      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
  }

  // Check whether any unused dim is non-unit; this only matters when a mask
  // has to be collapsed below.
  VectorType oldMaskType;
  bool isAnyUnusedDimNonUnit = false;
  if (maskingOp) {
    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
        isAnyUnusedDimNonUnit = true;
        break;
      }
    }
  }

  // Bail if compressing the unused dims would leave the contraction without a
  // reduction dimension applying to both sides.
  bool hasReductionIteratorApplyingOnBothSides = false;
  for (unsigned i = 0; i < iterators.size(); ++i) {
    if (!isReductionIterator(iterators[i]))
      continue;
    if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
      hasReductionIteratorApplyingOnBothSides = true;
      break;
    }
  }
  if (!hasReductionIteratorApplyingOnBothSides)
    return failure();
  // ...
  Operation *newOp = vector::ContractionOp::create(
      rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
      rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

  // Handle the mask.
  if (maskingOp) {
    if (isAnyUnusedDimNonUnit)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Cannot drop non-unit mask dim.");
    assert(unusedDimsBitVector.size() ==
               static_cast<size_t>(oldMaskType.getRank()) &&
           "The mask rank is incorrect!");

    // If a dimension has been dropped, update the mask accordingly.
    Value mask = maskingOp.getMask();
    if (unusedDimsBitVector.count() != 0) {
      auto newShape =
          oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
      auto newShapeScalableDims =
          oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
      VectorType maskOpType =
          VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
      mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
                                         maskOpType, maskingOp.getMask())
                 .getResult();
    }
    newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
  }
  return newOp->getResult(0);
}
struct CombineContractBroadcastMask
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
  }
};
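// Illustrative sketch (assumed types): casts are cheaper on the narrow,
// pre-broadcast value, so the pattern below swaps a cast with the broadcast
// feeding it:
//
//   %b = vector.broadcast %arg : vector<4xi32> to vector<8x4xi32>
//   %c = arith.sitofp %b : vector<8x4xi32> to vector<8x4xf32>
//
// becomes:
//
//   %c = arith.sitofp %arg : vector<4xi32> to vector<4xf32>
//   %b = vector.broadcast %c : vector<4xf32> to vector<8x4xf32>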
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
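// Illustrative sketch (assumed types): an elementwise op on transposed
// operands is re-ordered so that a single transpose runs on the result:
//
//   %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
//   %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
//   %r  = arith.addf %at, %bt : vector<2x4xf32>
//
// becomes:
//
//   %0 = arith.addf %a, %b : vector<4x2xf32>
//   %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>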
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Collect the permutations of all vector.transpose producers; every
    // operand must be a transpose (all with the same permutation) or a
    // constant.
    VectorType srcType;
    SmallVector<ArrayRef<int64_t>, 4> transposeMaps;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      }
      // ...
    }
    if (transposeMaps.empty())
      return failure();
    if (!llvm::all_equal(transposeMaps))
      return failure();

    // Constant operands are "un-transposed" with the inverse permutation so
    // the elementwise op can run on the original layout.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    SmallVector<Value, 4> srcValues;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        auto vectorType = srcType.clone(
            cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(vector::TransposeOp::create(
            rewriter, operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    // Run the elementwise op on the un-transposed operands, then transpose
    // its result once.
    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};
/// Returns the integer values in `arrayAttr` as a vector of int64_t.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
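// Illustrative sketch (assumed types): extracting one element of a bitcast
// that expands elements (e.g. f32 -> 2 x f16) only needs the single packed
// source element that contains it:
//
//   %cast = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %e = vector.extract %cast[5] : f16 from vector<8xf16>
//
// becomes (5 / 2 == 2 selects the packed element, 5 % 2 == 1 the half):
//
//   %packed = vector.extract %src[2] : f32 from vector<4xf32>
//   %ins = vector.insert %packed, %zero [0] : f32 into vector<1xf32>
//   %cast2 = vector.bitcast %ins : vector<1xf32> to vector<2xf16>
//   %e = vector.extract %cast2[1] : f16 from vector<2xf16>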
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Single-element sources are a degenerate case handled elsewhere.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to more elements for now, e.g.
    // vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // Only support static positions for now.
    auto mixedPos = extractOp.getMixedPosition();
    if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Extract the packed scalar that contains the requested element and wrap
    // it into a single-element vector.
    Location loc = extractOp.getLoc();
    Value packedValue = vector::ExtractOp::create(
        rewriter, loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
                                           rewriter.getZeroAttr(packedVecType));
    packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
                                           /*position=*/0);

    // Cast it to a vector with the desired scalar's type, e.g.
    // vector<1xf32> -> vector<2xf16>, then extract the desired scalar.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
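// Illustrative sketch (assumed types): an extract_strided_slice of a widening
// bitcast can slice the source first and bitcast the smaller slice, scaling
// offsets/sizes by the element-count ratio:
//
//   %cast = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %s = vector.extract_strided_slice %cast
//       {offsets = [4], sizes = [4], strides = [1]} : ...
//
// becomes:
//
//   %s0 = vector.extract_strided_slice %src
//       {offsets = [2], sizes = [2], strides = [1]} : ...
//   %s = vector.bitcast %s0 : vector<2xf32> to vector<4xf16>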
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If fewer offsets than the rank are given, the last bitcast dimension is
    // implicitly selected in full; otherwise scale down the last dimension's
    // offset because we now extract from fewer elements.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
        newOffsets, newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
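// Illustrative sketch (assumed types): a bitcast of a vector.insert is
// bubbled up so the insert operates on already-bitcast values:
//
//   %i = vector.insert %val, %dst [0] : vector<16xi8> into vector<2x16xi8>
//   %c = vector.bitcast %i : vector<2x16xi8> to vector<2x4xi32>
//
// becomes:
//
//   %v = vector.bitcast %val : vector<16xi8> to vector<4xi32>
//   %d = vector.bitcast %dst : vector<2x16xi8> to vector<2x4xi32>
//   %c = vector.insert %v, %d [0] : vector<4xi32> into vector<2x4xi32>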
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
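// Illustrative sketch (assumed types): the same bubbling applies to
// vector.insert_strided_slice, with offsets scaled by the shrink ratio:
//
//   %i = vector.insert_strided_slice %src, %dst
//       {offsets = [2], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %c = vector.bitcast %i : vector<8xf16> to vector<4xf32>
//
// becomes:
//
//   %s = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %d = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %c = vector.insert_strided_slice %s, %d
//       {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>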
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert's source and destination vectors to have the same
    // rank; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Require that the shape of the insert's source is castable to dstType.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
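// Illustrative sketch (assumed types): a shrinking rank-1 bitcast is broken
// into per-chunk pieces so each piece can be bitcast independently:
//
//   %c = vector.bitcast %src : vector<8xf16> to vector<4xf32>
//
// becomes a chain of extract_strided_slice + bitcast + insert_strided_slice
// ops, one per chunk, with applicability decided by `controlFn`.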
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on vector.{extract,insert}_strided_slice,
    // which do not support scalable sub-vectors.
    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support rank 1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to break down if it is already the same.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                           rewriter.getZeroAttr(elemType));
    Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = ExtractStridedSliceOp::create(
          rewriter, loc, bitcastOp.getSource(),
          ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
      Value bitcast =
          BitCastOp::create(rewriter, loc, newCastDstType, extracted);
      res = InsertStridedSliceOp::create(
          rewriter, loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(vector::BitCastOp)> controlFn;
};
/// Returns true if `t` and `u` are vectors with matching shape and
/// scalability, or if neither is a vector.
static bool haveSameShapeAndScaling(Type t, Type u) {
  auto tVec = dyn_cast<VectorType>(t);
  auto uVec = dyn_cast<VectorType>(u);
  if (!tVec)
    return !uVec;
  if (!uVec)
    return false;
  return tVec.getShape() == uVec.getShape() &&
         tVec.getScalableDims() == uVec.getScalableDims();
}

/// If `type` is shaped, clone it with `newElementType`; otherwise return
/// `newElementType` itself.
static Type cloneOrReplace(Type type, Type newElementType) {
  if (auto shapedType = dyn_cast<ShapedType>(type)) {
    return shapedType.clone(newElementType);
  }
  return newElementType;
}

/// If `value` is the result of a broadcast or splat, return its source;
/// otherwise return a null value.
static Value getBroadcastLikeSource(Value value) {
  Operation *op = value.getDefiningOp();
  if (!op)
    return {};
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcast.getSource();
  if (auto splat = dyn_cast<vector::SplatOp>(op))
    return splat.getInput();
  return {};
}
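// Illustrative sketch (assumed types): an elementwise op whose operands are
// all broadcasts/splats from the same shape (or splat constants) is rewritten
// to run before the broadcast:
//
//   %a = vector.broadcast %arg0 : vector<4xf32> to vector<8x4xf32>
//   %b = vector.broadcast %arg1 : vector<4xf32> to vector<8x4xf32>
//   %r = arith.addf %a, %b : vector<8x4xf32>
//
// becomes:
//
//   %0 = arith.addf %arg0, %arg1 : vector<4xf32>
//   %r = vector.broadcast %0 : vector<4xf32> to vector<8x4xf32>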
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    // ...
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    Type resultElemType = resultType.getElementType();

    // Get the type of the first non-constant operand.
    Value splatSource;
    for (Value operand : op->getOperands()) {
      Operation *definingOp = operand.getDefiningOp();
      if (!definingOp)
        return failure();
      if (definingOp->hasTrait<OpTrait::ConstantLike>())
        continue;
      splatSource = getBroadcastLikeSource(operand);
      break;
    }
    if (!splatSource)
      return failure();
    Type unbroadcastResultType =
        cloneOrReplace(splatSource.getType(), resultElemType);

    // All operands must be broadcast from identically-shaped values or be
    // splat constants; otherwise the re-ordering isn't safe.
    if (!llvm::all_of(op->getOperands(), [&splatSource](Value val) {
          if (auto source = getBroadcastLikeSource(val))
            return haveSameShapeAndScaling(source.getType(),
                                           splatSource.getType());
          SplatElementsAttr splatConst;
          return matchPattern(val, m_Constant(&splatConst));
        })) {
      return rewriter.notifyMatchFailure(
          op,
          "not all operands are constants or broadcasts from the same type");
    }

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      SplatElementsAttr splatConst;
      if (matchPattern(operand, m_Constant(&splatConst))) {
        // Rebuild the splat constant on the un-broadcast shape.
        Attribute newConst;
        Type elementType = getElementTypeOrSelf(operand.getType());
        Type newType = cloneOrReplace(unbroadcastResultType, elementType);
        if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
          newConst = splatConst.resizeSplat(newTypeShaped);
        } else {
          newConst = splatConst.getSplatValue<Attribute>();
        }
        Operation *newConstOp =
            operand.getDefiningOp()->getDialect()->materializeConstant(
                rewriter, newConst, newType, operand.getLoc());
        srcValues.push_back(newConstOp->getResult(0));
      } else {
        srcValues.push_back(operand.getDefiningOp()->getOperand(0));
      }
    }

    // Run the elementwise op on the un-broadcast sources and broadcast its
    // result.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        unbroadcastResultType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, resultType, elementwiseOp->getResults());
    return success();
  }
};
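// Illustrative sketch (assumed types): extracting one lane of an elementwise
// op only needs that lane of each operand, so the extract is hoisted above
// the elementwise computation:
//
//   %e = arith.addf %a, %b : vector<4xf32>
//   %r = vector.extract %e[1] : f32 from vector<4xf32>
//
// becomes:
//
//   %a1 = vector.extract %a[1] : f32 from vector<4xf32>
//   %b1 = vector.extract %b[1] : f32 from vector<4xf32>
//   %r = arith.addf %a1, %b1 : f32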
class ExtractOpFromElementwise final
    : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *eltwise = op.getVector().getDefiningOp();
    // vector.fma doesn't have a scalar form.
    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
        isa<vector::FMAOp>(eltwise))
      return rewriter.notifyMatchFailure(op, "not an elementwise op");
    // ...
    if (!op.getDynamicPosition().empty())
      return rewriter.notifyMatchFailure(op,
                                         "dynamic position not yet implemented");

    Type dstType = op.getType();
    // ...
    // Extract the requested position from each operand and map the old
    // arguments to the new ones.
    Location loc = op.getLoc();
    IRMapping mapping;
    auto pos = op.getMixedPosition();
    for (Value arg : eltwise->getOperands()) {
      Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
      mapping.map(arg, newArg);
    }
    // ... (clone the elementwise op on the extracted scalars)
  }
};
/// Checks whether the element type is suitable for vector.load/store sinking:
/// it must be index, or a byte-aligned integer or floating-point type.
static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;
  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}
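// Illustrative sketch (assumed types): a vector.load whose only user is a
// vector.extract of a single element can become a scalar memref.load at the
// adjusted index:
//
//   %v = vector.load %mem[%i] : memref<16xf32>, vector<4xf32>
//   %e = vector.extract %v[2] : f32 from vector<4xf32>
//
// becomes (with the extract position folded into the index computation):
//
//   %e = memref.load %mem[%i_plus_2] : memref<16xf32>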
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return rewriter.notifyMatchFailure(op, "expected a load op");

    // Checking for a single use so we can remove the load.
    if (!loadOp->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();

    // Non-byte-aligned types may require special handling; ignore for now.
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return rewriter.notifyMatchFailure(op, "unsupported element type");

    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    if (rankOffset < 0)
      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();

    SmallVector<Value> indices(loadOp.getIndices());
    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();

    // There may be memory stores between the load and the extract op, so the
    // new load is inserted at the position of the original load op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(loadOp);
    Location loc = loadOp.getLoc();
    ArithIndexingBuilder idxBuilderf(rewriter, loc);
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      OpFoldResult pos = extractPos[i - rankOffset];
      if (isZeroInteger(pos))
        continue;
      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    Value base = loadOp.getBase();
    if (extractVecType) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
                                                  indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
    }
    // We checked for a single use, so the load can be safely deleted.
    rewriter.eraseOp(loadOp);
    return success();
  }
};
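// Illustrative sketch (assumed types): storing a 1-element broadcast vector
// is the same as storing its source:
//
//   %v = vector.broadcast %s : f32 to vector<1xf32>
//   vector.store %v, %mem[%i] : memref<16xf32>, vector<1xf32>
//
// becomes:
//
//   memref.store %s, %mem[%i] : memref<16xf32>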
class StoreOpFromSplatOrBroadcast final
    : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Value toStore = op.getValueToStore();
    Value source = getBroadcastLikeSource(toStore);
    if (!source)
      return rewriter.notifyMatchFailure(
          op, "value to store is not from a broadcast");

    // ...
    Value base = op.getBase();
    // A vector-typed source (e.g. a 0-D vector) needs its scalar extracted
    // first; a scalar source is stored with a plain memref.store.
    if (isa<VectorType>(source.getType())) {
      // ...
    }
    // ...
  }
};
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If all indices are assumed to fit in 32 bits, the vector comparison is
  // performed in 32 bits for a higher degree of SIMD parallelism; otherwise
  // 64-bit indices are used.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
    indices = arith::AddIOp::create(rewriter, loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
  return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
                               indices, bounds);
}
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements in [0 .. dim - offset) set
    // and all elements in [dim - offset .. vector_length) unset.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
    Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
    Value mask = vector::CreateMaskOp::create(
        rewriter, loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds mask with the mask specified on the op.
      mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
    });
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constant.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
// ...
  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions are supported.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // TODO: Support n-D and scalable vectors.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();

    // TODO: Support vectors with multiple elements.
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace the select with its condition, materialized at the vector type.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    // ...
  }
/// Returns the number of inner unit dims that can be folded away from
/// transfer ops on `srcType`/`vectorType`, or failure if the strides can't be
/// determined. For example, with `srcType = memref<512x16x1x1xf32>` (identity
/// layout) and `vectorType = vector<16x16x1x1xf32>`, the two innermost unit
/// dims can be dropped via memref.subview.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // According to vector.transfer_read/write semantics, the vector can be a
  // slice, so the check index is offset by `rankDiff` into the source strides
  // and dim sizes.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // A dim is foldable only if it is unit-size in both the memref and the
    // vector, and its stride is 1.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}
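// Illustrative sketch (assumed types): trailing unit dims that are contiguous
// in memory are peeled off through a rank-reducing subview, e.g.
//
//   %v = vector.transfer_read %src[...], %pad
//       : memref<8x16x1xf32>, vector<4x1xf32>
//
// becomes:
//
//   %sv = memref.subview %src[...] : memref<8x16x1xf32> to memref<8x16xf32>
//   %r = vector.transfer_read %sv[...], %pad : memref<8x16xf32>, vector<4xf32>
//   %v = vector.shape_cast %r : vector<4xf32> to vector<4x1xf32>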
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getBase());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView =
        memref::SubViewOp::create(rewriter, loc, resultMemrefType,
                                  readOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = vector::TransferReadOp::create(
        rewriter, loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    Location loc = writeOp.getLoc();
    // ... (offsets/sizes/strides are built as in the read pattern above)
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

    Value rankedReducedView =
        memref::SubViewOp::create(rewriter, loc, resultMemrefType,
                                  writeOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};
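// Illustrative sketch (maps written informally): the pattern below normalizes
// every row-major matmul-like vector.contract to the "MMT" form A * B^T,
// i.e. indexing maps {(m, k), (n, k), (m, n)}, by transposing and/or swapping
// the operands. For example, a contract with maps {(m, k), (k, n), (m, n)}
// gets its RHS transposed first:
//
//   %rhs_t = vector.transpose %rhs, [1, 0]
//       : vector<16x8xf32> to vector<8x16xf32>
//   %r = vector.contract {indexing_maps = [(m, k), (n, k), (m, n)], ...}
//        %lhs, %rhs_t, %acc
//
// Transposes of arith.extsi/extui operands are pushed below the extension so
// the transpose runs on the narrower type.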
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector transpose, folding it through a preceding sign/zero
    // extension when possible so the transpose acts on the narrower type.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return arith::ExtSIOp::create(rewriter, loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return arith::ExtUIOp::create(rewriter, loc, newType, trans);
      }
      return vector::TransposeOp::create(rewriter, loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
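// Illustrative sketch (assumed types): mixed-precision contracts accept
// narrow inputs, so explicit extensions on both operands can be folded away:
//
//   %a = arith.extf %lhs : vector<4x2xf16> to vector<4x2xf32>
//   %b = arith.extf %rhs : vector<2x8xf16> to vector<2x8xf32>
//   %r = vector.contract {...} %a, %b, %acc : ... into vector<4x8xf32>
//
// becomes:
//
//   %r = vector.contract {...} %lhs, %rhs, %acc
//       : vector<4x2xf16>, vector<2x8xf16> into vector<4x8xf32>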
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // Accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();
    // ...
    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
                                   op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};
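// Illustrative examples (assumed shapes) for the helper below:
//   vector<1x8x1xf32>   -> vector<8xf32>     (non-scalable unit dims dropped)
//   vector<1x[4]x1xf32> -> vector<[4]xf32>   (scalable dim is kept)
//   vector<1x1xf32>     -> vector<1xf32>     (never drops down to rank 0)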
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped; return a one-element vector.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // For `Elementwise` ops all operands are guaranteed to have matching
    // shapes, so it suffices to check only one of them.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");

      auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Create an updated elementwise op with the new operands and result type,
    // and shape_cast back to the original result type.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, resultVectorType, elementwiseOp->getResult(0));
    return success();
  }
};
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop unit dims from the transpose permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // Fixup: if all dims were unit dims, `sourceTypeWithoutUnitDims` is
    // vector<1xT> and the new permutation must be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    // Drop the unit dims via shape_cast, then transpose without them.
    auto dropDimsShapeCast = vector::ShapeCastOp::create(
        rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
    auto transposeWithoutUnitDims =
        vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), transposeWithoutUnitDims);
    return success();
  }
};
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    /// Find the first iter_arg with droppable unit dims. Further applications
    /// of this pattern will handle the later arguments.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;

      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;

      // Cast the iter_arg (and the corresponding yielded value) through
      // vector.shape_cast on entry to and exit from the loop.
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return vector::ShapeCastOp::create(b, loc, type, source);
      };

      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp,
                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
                                                    replacement, castFn));
      return success();
    }
    return failure();
  }
};
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other reduction kinds and their identity values.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by `arith.addi` folders; only check
    // for floats here.
    if (!isa<FloatType>(elemType))
      return failure();

    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!addLhs)
      return failure();

    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
                                        addLhs.getLhs(), vAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
                                                     op.getAcc());
    return success();
  }
};
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
                                                static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};
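// Illustrative sketch (assumed types): a 2-D elementwise multiply of a
// transposed column broadcast and a row broadcast is an outer product:
//
//   %at = vector.transpose %a_bcast, [1, 0]
//       : vector<8x4xf32> to vector<4x8xf32>
//   %r = arith.mulf %at, %b_bcast : vector<4x8xf32>
//
// becomes (with %a : vector<4xf32>, %b : vector<8xf32>):
//
//   %r = vector.outerproduct %a, %b : vector<4xf32>, vector<8xf32>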
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;
  // Returns whether a vector.broadcast matches the requirements for an outer
  // product pattern.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail unless this is a true 1-D-to-2-D broadcast, to avoid generating
    // shape_casts/broadcasts which do not belong in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Avoid broadcasts such as f32 or vector<f32> -> ResType.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
    if (!resType)
      return failure();
    if (resType.getRank() != 2)
      return failure();
    /// If operandA can be written as tr(broadcast(A)) and operandB as
    /// broadcast(B), where the broadcasts are 1-D-to-2-D, create and return
    /// vector.outerproduct(A, B); otherwise return failure().
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return vector::OuterProductOp::create(
          rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity: the transposed op is the outerproduct LHS.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
  // ...
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"