#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "vector-to-vector"

template <typename IntType>
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
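/// Rewrites a multiplication followed by an additive multi-reduction into a
/// vector.contract. Illustrative sketch (shapes and value names assumed):
///   %m = arith.mulf %a, %b : vector<8x16xf32>
///   %r = vector.multi_reduction <add>, %m, %acc [1]
///          : vector<8x16xf32> to vector<8xf32>
/// becomes a vector.contract over %a and %b with a "reduction" iterator on
/// the reduced dimension and %acc as the accumulator.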
struct MultiReduceToContract
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
    if (!isReduceDim.value()) {
      iteratorTypes.push_back(vector::IteratorType::parallel);
      iteratorTypes.push_back(vector::IteratorType::reduction);
        0, exprs, reduceOp.getContext());
    return IteratorTypeAttr::get(rewriter.getContext(), t);
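/// Folds vector.transpose ops feeding the LHS/RHS of a vector.contract into
/// the contract's indexing maps, so the transposes disappear. Sketch
/// (illustrative shapes):
///   %t = vector.transpose %a, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///   %r = vector.contract ... %t, %b, %acc ...
/// becomes the same vector.contract on %a directly, with the LHS indexing
/// map composed with the permutation.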
struct CombineContractABTranspose final
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    for (Value *operand : {&lhs, &rhs}) {
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
          transposeOp.getPermutation(), contractOp.getContext());
      *operand = transposeOp.getVector();
        contractOp, lhs, rhs, contractOp.getAcc(),
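/// Merges a vector.transpose applied to the single-use result of a
/// vector.contract, whose accumulator is itself produced by a matching
/// vector.transpose, into the contract by composing the result indexing map
/// with the permutation. Both transposes are removed and the untransposed
/// accumulator is used directly.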
struct CombineContractResultTranspose final
  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    auto combinedResMap = resTMap.compose(contractMap);
    maps.back() = combinedResMap;
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
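/// Pulls vector.broadcast ops on the LHS/RHS of a (possibly masked)
/// vector.contract into the contract by composing the broadcast with the
/// operand's indexing map and dropping dimensions that existed only because
/// of the broadcast; for the masked form, the mask is shape_cast to the
/// reduced rank. The rewrite bails out on inner-dimension broadcasts, on
/// non-unit broadcast dimensions that are reduced, and when, after dropping
/// dims, no reduction iterator would apply to both operands.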
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
                                             MaskingOpInterface maskingOp,
      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
  Value lhs = contractOp.getLhs();
  Value rhs = contractOp.getRhs();
  for (Value *operand : {&lhs, &rhs}) {
    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
        srcType.getRank() == broadcast.getResultVectorType().getRank())
        broadcast.getResultVectorType().getRank() - srcType.getRank();
    bool innerDimBroadcast = false;
          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
        innerDimBroadcast = true;
    if (innerDimBroadcast)
    bool nonUnitDimReductionBroadcast = false;
    for (int64_t i = 0; i < rankDiff; ++i) {
      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
        nonUnitDimReductionBroadcast = true;
    if (nonUnitDimReductionBroadcast)
    map = broadcastMap.compose(map);
  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
    if (!unusedDimsBitVector.test(i))
      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
  VectorType oldMaskType;
  bool isAnyUnusedDimNonUnit = false;
    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
        isAnyUnusedDimNonUnit = true;
  bool hasReductionIteratorApplyingOnBothSides = false;
  for (unsigned i = 0; i < iterators.size(); ++i) {
      hasReductionIteratorApplyingOnBothSides = true;
  if (!hasReductionIteratorApplyingOnBothSides)
      contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
  if (isAnyUnusedDimNonUnit)
392 "Cannont drop non-unit mask dim.");
  assert(unusedDimsBitVector.size() ==
             static_cast<size_t>(oldMaskType.getRank()) &&
         "The mask rank is incorrect!");
  Value mask = maskingOp.getMask();
  if (unusedDimsBitVector.count() != 0) {
        oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
    auto newShapeScalableDims =
        oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
    VectorType maskOpType =
        .create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,

struct CombineContractBroadcastMask
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
    return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
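/// Reorders an elementwise cast op past a vector.broadcast so the cast runs
/// on the narrower source. Sketch (illustrative types):
///   %b = vector.broadcast %x : vector<4xi8> to vector<2x4xi8>
///   %c = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
/// is rewritten to cast %x first and then broadcast the casted value.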
struct ReorderCastOpsOnBroadcast
  LogicalResult matchAndRewrite(CastOpInterface op,
    if (op->getNumOperands() != 1)
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
        op, op->getResult(0).getType(), castOp->getResult(0));
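/// If every vector operand of an elementwise op comes from a vector.transpose
/// with the same permutation, performs the elementwise op on the untransposed
/// values and transposes the result once. Roughly:
///   %a = vector.transpose %x, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///   %b = vector.transpose %y, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///   %r = arith.addf %a, %b : vector<8x4xf32>
/// becomes %s = arith.addf %x, %y followed by a single vector.transpose.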
struct ReorderElementwiseOpsOnTranspose final
  LogicalResult matchAndRewrite(Operation *op,
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
    if (transposeMaps.empty())
    if (!llvm::all_equal(transposeMaps))
    auto order = transposeMaps.front();
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
        srcValues.push_back(transposeOp.getVector());
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
    auto vectorType = srcType.clone(
        transposeMaps.front());

  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
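/// Bubbles a rank-1 vector.bitcast below a vector.extract of one of its
/// elements, so only the relevant piece of the source is bitcast. Sketch
/// (illustrative types):
///   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
///   %1 = vector.extract %0[3] : f16 from vector<8xf16>
/// becomes: extract the packed f32 element %src[3 / 2], bitcast that single
/// element to vector<2xf16>, and extract element 3 % 2 from it.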
struct BubbleDownVectorBitCastForExtract
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
    if (extractOp.getSourceVectorType().getRank() != 1)
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    if (castSrcType.getNumElements() == 1)
    if (castSrcType.getNumElements() > castDstType.getNumElements())
    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
    VectorType packedType =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
        index % expandRatio);
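/// Bubbles a vector.bitcast below a vector.extract_strided_slice when the
/// slice boundaries line up with whole source elements. Sketch:
///   %cast = vector.bitcast %src : vector<4xf32> to vector<8xf16>
///   %0 = vector.extract_strided_slice %cast
///          {offsets = [4], sizes = [4], strides = [1]}
/// becomes an extract_strided_slice of %src with offsets and sizes divided by
/// the expand ratio (here 2), followed by a smaller vector.bitcast.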
struct BubbleDownBitCastForStridedSliceExtract
  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    if (castSrcLastDim > castDstLastDim)
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      if (offsets.back() % expandRatio != 0)
      offsets.back() = offsets.back() / expandRatio;
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      if (sizes.back() % expandRatio != 0)
      sizes.back() = sizes.back() / expandRatio;
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());
        extractOp, extractOp.getType(), newExtractOp);
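/// Bubbles a vector.bitcast above a vector.insert by bitcasting the inserted
/// value and the destination separately and re-inserting. Roughly:
///   %0 = vector.insert %val, %dst [0] : vector<4xf16> into vector<2x4xf16>
///   %1 = vector.bitcast %0 : vector<2x4xf16> to vector<2x2xf32>
/// becomes bitcasts of %val and %dst followed by a vector.insert at the same
/// position.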
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
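/// Bubbles a vector.bitcast above a vector.insert_strided_slice by bitcasting
/// the inserted slice and the destination separately and scaling the
/// innermost offset by the shrink ratio. Requires unit strides and offsets
/// that divide evenly by that ratio.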
struct BubbleUpBitCastForStridedSliceInsert
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    if (castSrcType.getRank() == 0)
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    if (castSrcLastDim < castDstLastDim)
    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
    unsigned rank = insertOp.getSourceVectorType().getRank();
    if (rank != insertOp.getDestVectorType().getRank())
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    if (offsets.back() % shrinkRatio != 0)
    offsets.back() = offsets.back() / shrinkRatio;
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
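/// Breaks a rank-1, element-shrinking vector.bitcast into several smaller
/// extract_strided_slice + bitcast + insert_strided_slice steps, under the
/// control of an optional callback. Sketch (illustrative types):
///   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
/// becomes two vector<4xf16> -> vector<2xf32> bitcasts of the two halves,
/// inserted into a zero-initialized vector<4xf32> (shrink ratio 2).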
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
                         std::function<bool(vector::BitCastOp)> controlFn,
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
    if (controlFn && !controlFn(bitcastOp))
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    if (castSrcType.isScalable())
          "Scalable vectors are not supported");
    if (castSrcType.getRank() != 1)
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    if (castSrcLastDim < castDstLastDim)
    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    if (castSrcLastDim == shrinkRatio)
    Type elemType = castDstType.getElementType();
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
    VectorType newCastDstType =
        castDstType.getElementType());
    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          sliceShape, strides);
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(

  std::function<bool(BitCastOp)> controlFn;
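/// Sinks a vector.broadcast (or vector.splat) below an elementwise op when
/// every operand is broadcast/splat from the same source type, so the
/// elementwise op runs on the narrow values. Roughly:
///   %a = vector.broadcast %x : index to vector<1x4xindex>
///   %b = vector.broadcast %y : index to vector<1x4xindex>
///   %r = arith.addi %a, %b : vector<1x4xindex>
/// becomes %s = arith.addi %x, %y : index, then one vector.broadcast of %s.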
struct ReorderElementwiseOpsOnBroadcast final
  LogicalResult matchAndRewrite(Operation *op,
          op, "Op doesn't have ElementwiseMappableTraits");
          "result and operand type mismatch");
    if (isa<vector::FMAOp>(op)) {
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
      auto bcast = val.getDefiningOp<vector::BroadcastOp>();
        return (bcast.getOperand().getType() == lhsBcastOrSplatType);
      auto splat = val.getDefiningOp<vector::SplatOp>();
        return (splat.getOperand().getType() == lhsBcastOrSplatType);
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
        lhsBcastOrSplatType, op->getAttrs());
        op, vectorType, elementwiseOp->getResults());
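/// Moves a vector.extract above a vectorized elementwise op, so the scalar
/// (or smaller) operation is performed instead. Roughly:
///   %0 = arith.addf %a, %b : vector<4xf32>
///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// becomes extracts of %a[1] and %b[1] followed by a scalar arith.addf.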
class ExtractOpFromElementwise final
  LogicalResult matchAndRewrite(vector::ExtractOp op,
    Operation *eltwise = op.getVector().getDefiningOp();
        isa<vector::FMAOp>(eltwise))
    Type dstType = op.getType();
      Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
      mapping.map(arg, newArg);

static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
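/// Folds a vector.extract of a single-use vector.load into a narrower load at
/// adjusted indices. Sketch (illustrative types):
///   %0 = vector.load %mem[%i] : memref<16xf32>, vector<4xf32>
///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// becomes a memref.load (or a smaller vector.load, if the extract result is
/// itself a vector) at index %i + 1.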
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
  LogicalResult matchAndRewrite(vector::ExtractOp op,
    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
    if (!loadOp->hasOneUse())
    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
          "scalable vectors are not supported");
    MemRefType memType = loadOp.getMemRefType();
    if (!isSupportedMemSinkElementType(memType.getElementType()))
    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
      finalRank = extractVecType.getRank();
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      indices[i] = idxBuilderf.add(indices[i], offset);
    Value base = loadOp.getBase();
    if (extractVecType) {
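/// Rewrites a vector.store of a one-element broadcast/splat as a store of the
/// underlying scalar. Roughly:
///   %v = vector.broadcast %s : f32 to vector<1xf32>
///   vector.store %v, %mem[%i] : memref<16xf32>, vector<1xf32>
/// becomes memref.store %s, %mem[%i] (or a vector.store of the narrower
/// source vector when the broadcast source is itself a vector).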
class StoreOpFromSplatOrBroadcast final
  LogicalResult matchAndRewrite(vector::StoreOp op,
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
          "scalable vectors are not supported");
    if (isa<VectorType>(op.getMemRefType().getElementType()))
          op, "memrefs of vectors are not supported");
    if (vecType.getNumElements() != 1)
          op, "only 1-element vectors are supported");
    Operation *splat = op.getValueToStore().getDefiningOp();
    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
    Value base = op.getBase();
    if (isa<VectorType>(source.getType())) {

                      bool force32BitVectorIndices, int64_t dim,
  if (dim == 0 && force32BitVectorIndices) {
  } else if (dim == 0) {
  } else if (force32BitVectorIndices) {
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
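/// Materializes the mask of an out-of-bounds 1-D vector.transfer_read/write:
/// buildVectorComparison above produces the constant vector [0, 1, ..., n-1],
/// optionally offsets it, and compares it (slt) against the splatted bound.
/// Sketch of the resulting mask (illustrative values):
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %msk = arith.cmpi slt, %idx, %bnd : vector<4xi32>
/// The transfer op then carries this mask, AND-ed with any existing mask.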
template <typename ConcreteOp>
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
    if (!xferOp.hasOutOfBoundsDim())
    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
    VectorType vtp = xferOp.getVectorType();
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value mask = rewriter.create<vector::CreateMaskOp>(
            vtp.getScalableDims()),
    if (xferOp.getMask()) {
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
      xferOp.getMaskMutable().assign(mask);

  const bool force32BitVectorIndices;

class VectorCreateMaskOpConversion
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
    int64_t rank = dstType.getRank();
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),

  const bool force32BitVectorIndices;

static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;

  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
    if (vecType.getRank() != 1 || vecType.isScalable())
    if (vecType.getShape()[0] != 1)
    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());

static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
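/// Drops trailing unit dims (stride-1, size-1 in both the memref and the
/// vector) from a vector.transfer_read by reading through a rank-reduced
/// memref.subview. Sketch (illustrative types):
///   %v = vector.transfer_read %src[...], %pad
///          : memref<8x1xf32>, vector<4x1xf32>
/// becomes a subview of type memref<8xf32, strided<...>>, a transfer_read of
/// vector<4xf32>, and a shape_cast back to vector<4x1xf32>.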
class DropInnerMostUnitDimsTransferRead
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
    if (readOp.getTransferRank() == 0)
    if (readOp.getMask())
    auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
    if (!readOp.getPermutationMap().isMinorIdentity())
    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
    auto resultTargetVecType =
        targetType.getElementType(),
        targetType.getScalableDims().drop_back(dimsToDrop));
    auto loc = readOp.getLoc();
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getBase(), offsets, sizes, strides);
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getPadding(),
        Value(), inBoundsAttr);

class DropInnerMostUnitDimsTransferWrite
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
    if (writeOp.getTransferRank() == 0)
    if (writeOp.getMask())
    auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
    if (!writeOp.getPermutationMap().isMinorIdentity())
    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
    auto resultTargetVecType =
        targetType.getElementType(),
        targetType.getScalableDims().drop_back(dimsToDrop));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getBase(), offsets, sizes, strides);
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
        writeOp, shapeCast, rankedReducedView,
        Value(), inBoundsAttr);
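/// Canonicalizes a matmul-like vector.contract to the "MMT" form, i.e.
/// indexing maps (m, k), (n, k) -> (m, n), by inserting vector.transpose ops
/// on the operands (and swapping them) as needed. Transposes are hoisted
/// above arith.extsi/extui so they run on the narrower type. For example, a
/// row-major matmul with maps (m, k), (k, n) -> (m, n) gets its RHS
/// transposed.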
struct CanonicalizeContractMatmulToMMT final
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
                                  FilterConstraintType constraint)
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
    if (failed(filter(op)))
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();
    auto infer = [&](MapList m) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    if (iteratorTypes.size() != 3 ||
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
        op.getIteratorTypes());

  FilterConstraintType filter;
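/// Folds a pair of arith extension ops (e.g. arith.extf or arith.extsi) on
/// both operands of a vector.contract into the contract itself, producing a
/// mixed-precision contraction on the narrower inputs.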
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
    if (!lhsDefOp || !rhsDefOp) {
          "no defining op on contract operands");
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
  LogicalResult matchAndRewrite(vector::ReductionOp op,
    if (op.getKind() != vector::CombiningKind::ADD)
    Value acc = op.getAcc();
    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
    if (isa<IntegerType>(acc.getType())) {
          loc, parentReduction.getVector(), op.getVector());
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
        parentReduction.getAcc());

static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);

struct DropUnitDimFromElementwiseOps final
  LogicalResult matchAndRewrite(Operation *op,
    if (!resultVectorType)
    if (!sourceVectorType)
    if (sourceVectorType.getRank() < 2)
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
        newResultVectorType, op->getAttrs());

struct DropUnitDimsFromTransposeOp final
  LogicalResult matchAndRewrite(vector::TransposeOp op,
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);
    if (sourceType == sourceTypeWithoutUnitDims)
    int64_t droppedDims = 0;
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    if (newPerm.empty()) {
      newPerm.push_back(0);
    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
        loc, sourceTypeWithoutUnitDims, op.getVector());
    auto transposeWithoutUnitDims =
        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
        op, op.getResultVectorType(), transposeWithoutUnitDims);

  LogicalResult matchAndRewrite(scf::ForOp forOp,
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        return b.create<vector::ShapeCastOp>(loc, type, source);
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
          replacement, castFn));

  LogicalResult matchAndRewrite(vector::ReductionOp op,
    if (op.getKind() != vector::CombiningKind::ADD)
    Type elemType = op.getSourceVectorType().getElementType();
    if (!isa<FloatType>(elemType))
    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
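/// Breaks a small 1-D vector.reduction into scalar vector.extract ops
/// combined with scalar arithmetic matching the combining kind, up to a
/// configurable number of elements. Roughly:
///   %r = vector.reduction <add>, %v : vector<2xf32> into f32
/// becomes two vector.extract ops and one arith.addf (plus one more addf if
/// the reduction carries an accumulator).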
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
                           unsigned maxNumElementsToExtract,
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
    assert(type.getRank() == 1 && "Expected a 1-d vector");
    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));
    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
          extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
          op.getFastmathAttr());

  unsigned maxNumElementsToExtract = 0;
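/// Matches a 2-D elementwise multiply whose operands are a transposed
/// broadcast and a broadcast of rank-1 (or scalar) values, and folds it into
/// vector.outerproduct. Sketch (illustrative shapes):
///   %a = vector.broadcast %x : vector<4xf32> to vector<8x4xf32>
///   %t = vector.transpose %a, [1, 0] : vector<8x4xf32> to vector<4x8xf32>
///   %b = vector.broadcast %y : vector<8xf32> to vector<4x8xf32>
///   %m = arith.mulf %t, %b : vector<4x8xf32>
/// becomes %m = vector.outerproduct %x, %y : vector<4xf32>, vector<8xf32>.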
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;

  LogicalResult matchAndRewrite(MulOpType mulOp,
    auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
    if (resType.getRank() != 2)
    auto matchOuterProduct =
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());

  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(

void mlir::vector::populateVectorMaskMaterializationPatterns(
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
      std::move(controlFn), benefit);

    std::function<LogicalResult(vector::ContractionOp)> constraint,
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));

  patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
               CombineContractABTranspose, CombineContractResultTranspose>(

  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),

  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(

  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(

void mlir::vector::populateChainedVectorReductionFoldingPatterns(

void mlir::vector::populateBreakDownVectorReductionPatterns(
      maxNumElementsToExtract, benefit);

  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"