#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "vector-to-vector"
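/// Helper to pull the integer values out of an ArrayAttr of IntegerAttrs.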
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
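/// Convert a vector.multi_reduction with combining kind ADD whose source is an
/// elementwise multiply into a vector.contract. Roughly:
///
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction <add>, %0, %acc [1]
///       : vector<8x32x16xf32> to vector<8x16xf32>
///
/// becomes a vector.contract over %arg0 and %arg1 with iterator types
/// ["parallel", "reduction", "parallel"].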
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    // Each non-reduced dimension becomes a parallel iterator of the
    // contraction; each reduced dimension becomes a reduction iterator.
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim :
         llvm::enumerate(reduceOp.getReductionMask())) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto srcMap = rewriter.getMultiDimIdentityMap(iteratorTypes.size());
    auto dstMap = AffineMap::get(/*dimCount=*/iteratorTypes.size(),
                                 /*symbolCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](vector::IteratorType t) -> Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            }))));
    return success();
  }
};
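/// Merge a vector.transpose feeding the LHS or RHS of a vector.contract into
/// the contract's indexing maps.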
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      // Compose the inverse of the transpose permutation into this operand's
      // indexing map.
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
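/// Merge a vector.transpose of a vector.contract result (with a matching
/// transpose of the accumulator) into the contract's result indexing map.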
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    // The new result map is the old result map composed with the result
    // transpose's permutation; the accumulator transpose must be compatible.
    // ...
    auto combinedResMap = resTMap.compose(contractMap);
    // ...
    maps.back() = combinedResMap;
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
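/// Fold vector.broadcast ops feeding a vector.contract into the contract's
/// indexing maps, then compress away any dims that became unused.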
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // The contract op can only take vectors as operands.
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
      if (!srcType ||
          srcType.getRank() == broadcast.getResultVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // The contract op doesn't support inner dimension broadcasts.
      if (innerDimBroadcast)
        continue;

      // It would be incorrect to fold a broadcast onto a reduction dimension
      // of non-unit size.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                         originalDims, contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }

    if (!changed)
      return failure();

    // Determine which dims are unused now that the maps have been composed
    // with the broadcast maps, and compress them away.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    SmallVector<Attribute, 4> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // Bail if compressing the unused dims would leave no reduction dimension
    // that applies to both sides of the contraction.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      // ...
      hasReductionIteratorApplyingOnBothSides = true;
      break;
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};
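/// Reorder cast(broadcast) to broadcast(cast) so the cast runs on the smaller,
/// pre-broadcast value.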
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
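/// Reorder elementwise(transpose) to transpose(elementwise): apply the
/// elementwise op to the untransposed operands and transpose the result once.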
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayRef<int64_t>, 4> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. All the operands may have
    // the same shape but different types due to mixed-precision ops.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    // All operands being constants means there is no transpose to fold.
    if (transposeMaps.empty())
      return failure();
    // All transposes must share the same permutation.
    if (!llvm::all_equal(transposeMaps))
      return failure();

    // If there are constant operands, transpose them with the inverse
    // permutation so the final transpose restores the original order.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;
    SmallVector<Value, 4> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant: create an inverse transpose op for it.
        auto vectorType =
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};
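/// Returns the integer values of `arrayAttr` as a vector of int64_t.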
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
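/// Bubble a vector.bitcast down across a vector.extract: extract the packed
/// source element that holds the requested value, bitcast only that element,
/// and extract the scalar from the small result.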
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if there is only one element in the cast op source.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now, e.g.
    // vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // Only support static positions.
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Get the single source element that packs the desired scalar, e.g.
    // extract one f32 from vector<4xf32>, and wrap it in a 1-element vector.
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Cast it to a vector with the desired scalar type, e.g. f32 ->
    // vector<2xf16>, and extract the desired scalar.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
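/// Bubble a vector.bitcast down across a vector.extract_strided_slice by
/// extracting a rescaled slice from the bitcast source first.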
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If there are fewer offsets than the rank, we implicitly select the full
    // range for the last dimension; otherwise the last dimension's offset must
    // be scaled down since we extract from fewer elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for the last dimension's size.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
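/// Bubble a vector.bitcast up across a vector.insert by bitcasting both the
/// inserted value and the destination, then re-inserting.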
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!insertSrcType)
      return failure();

    // Bitcast the inserted value.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
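/// Bubble a vector.bitcast up across a vector.insert_strided_slice, rescaling
/// the innermost offsets by the element-count shrink ratio.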
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which will be lowered to scalars.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert op to have the same rank for source and destination
    // vectors; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Requires that the shape of the insert op source is castable to dstType.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
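/// Break a rank-1, element-count-shrinking vector.bitcast into a loop of
/// strided-slice extract + small bitcast + strided-slice insert, guarded by an
/// optional control predicate.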
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on vector.{extract|insert}_strided_slice,
    // which do not support scalable vectors.
    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support the rank-1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};
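/// Reorder elementwise(broadcast/splat) to broadcast(elementwise) so the
/// elementwise op runs on the narrow pre-broadcast values.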
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1)
      return failure();
    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    if (op->getNumOperands() == 0)
      return failure();
    if (op->getResults()[0].getType() != op->getOperand(0).getType())
      return rewriter.notifyMatchFailure(op,
                                         "result and operand type mismatch");
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    // Get the type of the lhs operand.
    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();

    // Make sure that all operands are broadcast from identical types.
    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
          if (bcast)
            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
          auto splat = val.getDefiningOp<vector::SplatOp>();
          if (splat)
            return (splat.getOperand().getType() == lhsBcastOrSplatType);
          return false;
        }))
      return failure();

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands())
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));

    // Create the elementwise op on the pre-broadcast operands and broadcast
    // its result.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        lhsBcastOrSplatType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResults()[0].getType(), elementwiseOp->getResults()[0]);
    return success();
  }
};
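/// Sink a vector.extract through an elementwise op: extract from each operand
/// first, then clone the elementwise op on the extracted values.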
class ExtractOpFromElementwise final
    : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *eltwise = op.getVector().getDefiningOp();
    // vector.fma has no scalar counterpart.
    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
        isa<vector::FMAOp>(eltwise))
      return rewriter.notifyMatchFailure(op, "not an elementwise op");
    // ...
    Type dstType = op.getType();
    // ...
    IRMapping mapping;
    Location loc = eltwise->getLoc();
    SmallVector<OpFoldResult> pos = op.getMixedPosition();
    for (Value arg : eltwise->getOperands()) {
      Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
      mapping.map(arg, newArg);
    }
    // Clone the elementwise op onto the extracted values and replace the
    // original extract with its result.
    // ...
    return success();
  }
};
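/// Checks whether an element type is suitable for load/store sinking: index,
/// or a byte-aligned integer or floating-point type.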
static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;
  // Otherwise require a byte-aligned int or float element type.
  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}
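/// Rewrite extract(vector.load) as a narrower load at offset indices when the
/// load has no other users. Roughly:
///
///   %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
///
/// becomes a memref.load of %arg0[%arg1 + 1].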
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return failure();

    // Checking for a single use so we can remove the load afterwards.
    if (!loadOp->hasOneUse())
      return failure();

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();
    // Non-byte-aligned element types are tricky; ignore them for now.
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return failure();

    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    // ...
    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();
    // ...
    // Offset the load indices by the extract position (`indices` holds the
    // load's indices; `idxBuilderf` builds the index arithmetic).
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      // ...
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    Value base = loadOp.getBase();
    if (extractVecType) {
      // Extracting a sub-vector: replace with a narrower vector.load.
      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
                                                  indices);
    } else {
      // Extracting a scalar: replace with a memref.load.
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
    }
    return success();
  }
};
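/// Rewrite a vector.store of a 1-element splat/broadcast as a scalar
/// memref.store of the splatted value.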
class StoreOpFromSplatOrBroadcast final
    : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Operation *splat = op.getValueToStore().getDefiningOp();
    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
      return rewriter.notifyMatchFailure(op, "not a splat or broadcast");

    Value source = splat->getOperand(0);
    Value base = op.getBase();
    SmallVector<Value> indices(op.getIndices().begin(), op.getIndices().end());

    if (isa<VectorType>(source.getType())) {
      // The splat source is itself a vector; extract the scalar element first.
      // ...
    }
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
    return success();
  }
};
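/// Build an i1 mask as a vector comparison: materialize the constant index
/// vector [0, 1, ..., dim-1] (as i32 if requested), add the optional offset,
/// and compare it against the splatted bound with slt.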
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If all indices fit in 32 bits, use an i32 index vector; otherwise i64.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}
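/// Attach an explicit in-bounds mask to a 1-D vector transfer op that has a
/// potentially out-of-bounds dimension, turning the access in-bounds.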
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds mask with the mask specified as an op
      // parameter.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  const bool force32BitVectorIndices;
};
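/// Lower vector.create_mask on fixed-size vectors to the index-vector
/// comparison built by buildVectorComparison.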
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
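/// Returns true if `constantOp` is a dense i1 splat whose value equals
/// `value`.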
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constants.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
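/// Fold arith.select between an all-true and an all-false single-element i1
/// vector constant into a direct materialization of the scalar condition.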
struct FoldI1Select final : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // TODO: Support n-D and scalable vectors.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();

    // TODO: Support vectors with multiple elements.
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // select(cond, all-true, all-false) is the condition itself; materialize
    // it as a single-element vector.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    // ...
    return success();
  }
};
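/// Returns how many trailing dims can be dropped from a transfer op: trailing
/// dims that are unit-sized with stride 1 in the memref and unit-sized,
/// non-scalable in the vector type.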
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // According to vector.transfer_read/write semantics, the vector can be a
  // slice, so the check index must be offset by `rankDiff` into `srcStrides`
  // and the source dim sizes.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // Walk the dims from the innermost outwards; a dim can be folded only if
    // it is unit-sized in both types and has stride 1.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}
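/// Drop the innermost unit dims of a vector.transfer_read by reading through a
/// rank-reduced memref.subview and shape-casting the result back.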
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
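/// The vector.transfer_write counterpart: shape-cast the vector down and write
/// through a rank-reduced memref.subview.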
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    Location loc = writeOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};
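/// Canonicalize a 2-D vector.contract with matmul semantics to the row-major
/// MMT form with maps {(m, k), (n, k), (m, n)}, inserting vector.transpose ops
/// (looking through arith.extsi/extui) as needed.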
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector transpose, looking through any integer extension so the
    // transpose applies to the narrower pre-extension type.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};
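/// Fold a pair of ExtOp arith extensions feeding both contraction operands
/// into a mixed-precision vector.contract.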
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());
    return success();
  }
};
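/// Fold chained ADD reductions: reduction(v1, reduction(v0, acc)) becomes
/// reduction(v0 + v1, acc).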
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // The accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();
    if (!acc.getType().isIntOrFloat())
      return failure();

    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};
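/// Returns `inVecTy` with all non-scalable unit dims dropped; keeps a single
/// unit dim if everything would be dropped.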
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped: return a one-element vector.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
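/// For elementwise ops on vectors with non-scalable unit dims, run the op on
/// shape-cast operands of lower rank and shape-cast the result back.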
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // Check the operand pre-conditions. For `Elementwise` ops all operands
    // are guaranteed to have identical shapes, so it suffices to check one.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");

      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Create an updated elementwise op without unit dims, then restore them by
    // shape-casting the result.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, resultVectorType, elementwiseOp->getResult(0));
    return success();
  }
};
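/// Drop non-scalable unit dims from a vector.transpose source and rewrite the
/// permutation on the remaining dims.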
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop the unit dims from the permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // Fixup: if all dims were unit dims, `sourceTypeWithoutUnitDims` is
    // vector<1xT> and the new permutation must be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    // Drop the unit dims via shape_cast, transpose, then restore them.
    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
        loc, sourceTypeWithoutUnitDims, op.getVector());
    auto transposeWithoutUnitDims =
        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), transposeWithoutUnitDims);
    return success();
  }
};
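/// Drop non-scalable unit dims from vector iter_args of an scf.for,
/// shape-casting across the loop boundary.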
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Find the first iter_arg with droppable unit dims. Further applications
    // of this pattern will handle later arguments.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;

      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;

      // Create a new ForOp with that iter operand replaced.
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return b.create<vector::ShapeCastOp>(loc, type, source);
      };

      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp,
                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
                                                    replacement, castFn));
      return success();
    }
    return failure();
  }
};
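/// Remove a redundant +0.0 inside a chained ADD reduction:
/// reduce(add((x + 0.0) + y)) becomes reduce(add(x + y)).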
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other reduction kinds and their identity values.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by `arith` canonicalizations.
    if (!isa<FloatType>(elemType))
      return failure();

    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!addLhs)
      return failure();
    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
                                                 vAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
                                                     op.getAcc());
    return success();
  }
};
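/// Break a small 1-D vector.reduction (at most `maxNumElementsToExtract`
/// elements) into scalar extracts combined with the matching arith ops.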
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};
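/// Match transpose(broadcast(A)) * broadcast(B) on 2-D vectors and fold it
/// into vector.outerproduct A, B.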
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;
  // Returns true if a broadcast fits the outer-product form: a broadcast along
  // the outermost dimension only, from a non-2-D vector source.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail if it is not a 1-to-2 dimension broadcast, to avoid generating
    // shape_casts/broadcasts which do not belong in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Avoid broadcasts like f32 or vector<f32> -> ResType.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
    if (!resType)
      return failure();
    if (resType.getRank() != 2)
      return failure();
    // If operandA can be written as tr(broadcast(A)) and operandB as
    // broadcast(B), where the broadcasts are along the outer dimension, then
    // the multiply is an outer product A x B.
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity: the transposed op is the outer product's LHS.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
      patterns.getContext());
}

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"