18 #include <type_traits>
42 #include "llvm/ADT/DenseSet.h"
43 #include "llvm/ADT/MapVector.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/Debug.h"
47 #include "llvm/Support/raw_ostream.h"
49 #define DEBUG_TYPE "vector-to-vector"
54 template <
typename IntType>
56 return llvm::to_vector<4>(llvm::map_range(
57 arrayAttr.getAsRange<IntegerAttr>(),
58 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
92 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
95 auto sourceVectorType =
96 shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
97 auto resultVectorType =
98 shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
99 if (!sourceVectorType || !resultVectorType)
103 auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
104 shapeCastOp.getSource().getDefiningOp());
105 if (!sourceShapeCastOp)
107 auto operandSourceVectorType =
108 sourceShapeCastOp.getSource().getType().cast<VectorType>();
109 auto operandResultVectorType = sourceShapeCastOp.getType();
112 if (operandSourceVectorType != resultVectorType ||
113 operandResultVectorType != sourceVectorType)
116 rewriter.
replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
138 struct MultiReduceToContract
142 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
144 if (reduceOp.getKind() != vector::CombiningKind::ADD)
146 Operation *mulOp = reduceOp.getSource().getDefiningOp();
147 if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
154 if (!isReduceDim.value()) {
155 iteratorTypes.push_back(vector::IteratorType::parallel);
158 iteratorTypes.push_back(vector::IteratorType::reduction);
162 0, exprs, reduceOp.getContext());
168 return IteratorTypeAttr::get(rewriter.getContext(), t);
197 struct CombineContractABTranspose final
201 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
204 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
205 Value lhs = contractOp.getLhs();
206 Value rhs = contractOp.getRhs();
208 bool changed =
false;
209 for (
Value *operand : {&lhs, &rhs}) {
211 auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
215 extractVector<unsigned>(transposeOp.getTransp()),
216 contractOp.getContext());
218 *operand = transposeOp.getVector();
224 contractOp, lhs, rhs, contractOp.getAcc(),
262 struct CombineContractResultTranspose final
268 auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
269 if (!contractOp || !contractOp->hasOneUse())
272 auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
277 auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
283 extractVector<unsigned>(accTOp.getTransp()), context);
288 extractVector<unsigned>(resTOp.getTransp()), context);
289 auto combinedResMap = resTMap.compose(contractMap);
296 maps.back() = combinedResMap;
299 resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
327 struct CombineContractBroadcast
331 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
334 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
335 Value lhs = contractOp.getLhs();
336 Value rhs = contractOp.getRhs();
338 bool changed =
false;
339 for (
Value *operand : {&lhs, &rhs}) {
347 srcType.getRank() ==
broadcast.getResultVectorType().getRank())
350 broadcast.getResultVectorType().getRank() - srcType.getRank();
351 bool innerDimBroadcast =
false;
354 if (dim.value() !=
broadcast.getResultVectorType().getDimSize(
355 rankDiff + dim.index())) {
356 innerDimBroadcast =
true;
359 originalDims.push_back(
364 if (innerDimBroadcast)
369 bool nonUnitDimReductionBroadcast =
false;
370 for (int64_t i = 0; i < rankDiff; ++i) {
371 if (
broadcast.getResultVectorType().getDimSize(i) != 1 &&
374 nonUnitDimReductionBroadcast =
true;
378 if (nonUnitDimReductionBroadcast)
384 map = broadcastMap.
compose(map);
400 for (
unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
401 if (!unusedDimsBitVector.test(i))
402 iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
409 bool hasReductionIteratorApplyingOnBothSides =
false;
410 for (
unsigned i = 0; i < iterators.size(); ++i) {
414 hasReductionIteratorApplyingOnBothSides =
true;
418 if (!hasReductionIteratorApplyingOnBothSides)
426 contractOp, lhs, rhs, contractOp.getAcc(),
445 struct ReorderCastOpsOnBroadcast
451 if (op->getNumOperands() != 1)
453 auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
458 if (
auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
459 castResTy = VectorType::get(vecTy.getShape(), castResTy);
461 rewriter.
create(op->getLoc(), op->getName().getIdentifier(),
462 bcastOp.getSource(), castResTy, op->getAttrs());
464 op, op->getResult(0).getType(), castOp->getResult(0));
483 struct ReorderElementwiseOpsOnTranspose final
499 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
501 transposeMaps.push_back(transposeOp.getTransp());
502 srcType = transposeOp.getSourceVectorType();
507 if (transposeMaps.empty())
512 if (!llvm::all_equal(transposeMaps))
520 auto order = extractVector<unsigned>(transposeMaps.front());
522 for (
int i = 0, e = order.size(); i < e; ++i)
523 invOrder[order[i]] = i;
526 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
528 srcValues.push_back(transposeOp.getVector());
531 auto vectorType = VectorType::get(
533 operand.getType().cast<VectorType>().getElementType());
534 srcValues.push_back(rewriter.
create<vector::TransposeOp>(
535 operand.getLoc(), vectorType, operand,
540 auto vectorType = VectorType::get(
548 transposeMaps.front());
555 return llvm::to_vector<4>(
556 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
557 [](IntegerAttr attr) { return attr.getInt(); }));
569 struct BubbleDownVectorBitCastForExtract
576 if (extractOp.getSourceVectorType().getRank() != 1)
579 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
583 VectorType castSrcType = castOp.getSourceVectorType();
584 VectorType castDstType = castOp.getResultVectorType();
585 assert(castSrcType.getRank() == castDstType.getRank());
590 if (castSrcType.getNumElements() == 1)
595 if (castSrcType.getNumElements() > castDstType.getNumElements())
598 unsigned expandRatio =
599 castDstType.getNumElements() / castSrcType.getNumElements();
602 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
609 VectorType oneScalarType =
610 VectorType::get({1}, castSrcType.getElementType());
611 Value packedValue = rewriter.
create<vector::ExtractOp>(
612 extractOp.getLoc(), oneScalarType, castOp.getSource(),
617 VectorType packedType =
618 VectorType::get({expandRatio}, castDstType.getElementType());
619 Value castedValue = rewriter.
create<vector::BitCastOp>(
620 extractOp.getLoc(), packedType, packedValue);
624 extractOp, extractOp.getType(), castedValue,
643 struct BubbleDownBitCastForStridedSliceExtract
647 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
649 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
653 VectorType castSrcType = castOp.getSourceVectorType();
654 VectorType castDstType = castOp.getResultVectorType();
655 assert(castSrcType.getRank() == castDstType.getRank());
657 int64_t castSrcLastDim = castSrcType.getShape().back();
658 int64_t castDstLastDim = castDstType.getShape().back();
660 if (castSrcLastDim > castDstLastDim)
664 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
665 [](
const APInt &val) { return !val.isOne(); }))
668 unsigned rank = extractOp.getSourceVectorType().getRank();
669 assert(castDstLastDim % castSrcLastDim == 0);
670 int64_t expandRatio = castDstLastDim / castSrcLastDim;
676 ArrayAttr newOffsets = extractOp.getOffsets();
677 if (newOffsets.size() == rank) {
679 if (offsets.back() % expandRatio != 0)
681 offsets.back() = offsets.back() / expandRatio;
686 ArrayAttr newSizes = extractOp.getSizes();
687 if (newSizes.size() == rank) {
689 if (sizes.back() % expandRatio != 0)
691 sizes.back() = sizes.back() / expandRatio;
696 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
697 dims.back() = dims.back() / expandRatio;
698 VectorType newExtractType =
699 VectorType::get(dims, castSrcType.getElementType());
701 auto newExtractOp = rewriter.
create<vector::ExtractStridedSliceOp>(
702 extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
703 newSizes, extractOp.getStrides());
706 extractOp, extractOp.getType(), newExtractOp);
723 struct BubbleUpBitCastForStridedSliceInsert
729 VectorType castSrcType = bitcastOp.getSourceVectorType();
730 VectorType castDstType = bitcastOp.getResultVectorType();
731 assert(castSrcType.getRank() == castDstType.getRank());
733 if (castSrcType.getRank() == 0)
736 int64_t castSrcLastDim = castSrcType.getShape().back();
737 int64_t castDstLastDim = castDstType.getShape().back();
739 if (castSrcLastDim < castDstLastDim)
742 assert(castSrcLastDim % castDstLastDim == 0);
743 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
746 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
751 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
752 [](
const APInt &val) { return !val.isOne(); }))
755 unsigned rank = insertOp.getSourceVectorType().getRank();
758 if (rank != insertOp.getDestVectorType().getRank())
762 unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
763 unsigned destinationWidth =
764 castDstType.getElementType().getIntOrFloatBitWidth();
765 unsigned numElements = destinationWidth / sourceWidth;
766 if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
769 ArrayAttr newOffsets = insertOp.getOffsets();
770 assert(newOffsets.size() == rank);
772 if (offsets.back() % shrinkRatio != 0)
774 offsets.back() = offsets.back() / shrinkRatio;
778 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
779 srcDims.back() = srcDims.back() / shrinkRatio;
780 VectorType newCastSrcType =
781 VectorType::get(srcDims, castDstType.getElementType());
783 auto newCastSrcOp = rewriter.
create<vector::BitCastOp>(
784 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
787 llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
788 dstDims.back() = dstDims.back() / shrinkRatio;
789 VectorType newCastDstType =
790 VectorType::get(dstDims, castDstType.getElementType());
792 auto newCastDstOp = rewriter.
create<vector::BitCastOp>(
793 bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
796 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
797 insertOp.getStrides());
813 bool force32BitVectorIndices, int64_t dim,
822 if (dim == 0 && force32BitVectorIndices) {
825 }
else if (dim == 0) {
828 }
else if (force32BitVectorIndices) {
830 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
833 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
835 Value indices = rewriter.
create<arith::ConstantOp>(loc, indicesAttr);
840 indices = rewriter.
create<arith::AddIOp>(loc, ov, indices);
845 rewriter.
create<vector::SplatOp>(loc, indices.
getType(), bound);
846 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
850 template <
typename ConcreteOp>
853 explicit MaterializeTransferMask(
MLIRContext *context,
bool enableIndexOpt,
856 force32BitVectorIndices(enableIndexOpt) {}
860 if (!xferOp.hasOutOfBoundsDim())
863 if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
867 VectorType vtp = xferOp.getVectorType();
874 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
875 Value off = xferOp.getIndices()[lastIndex];
879 Value mask = rewriter.
create<vector::CreateMaskOp>(
881 VectorType::get(vtp.getShape(), rewriter.
getI1Type(),
882 vtp.getNumScalableDims()),
884 if (xferOp.getMask()) {
886 mask = rewriter.
create<arith::AndIOp>(loc, mask, xferOp.getMask());
890 xferOp.getMaskMutable().assign(mask);
898 const bool force32BitVectorIndices;
902 class VectorCreateMaskOpConversion
905 explicit VectorCreateMaskOpConversion(
MLIRContext *context,
909 force32BitVectorIndices(enableIndexOpt) {}
913 auto dstType = op.getType();
914 if (dstType.cast<VectorType>().isScalable())
916 int64_t rank = dstType.getRank();
920 op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
921 rank == 0 ? 0 : dstType.getDimSize(0),
927 const bool force32BitVectorIndices;
931 class DropInnerMostUnitDims :
public OpRewritePattern<vector::TransferReadOp> {
937 if (readOp.getTransferRank() == 0)
941 if (readOp.getMask())
944 auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
945 if (!srcType || !srcType.hasStaticShape())
948 if (!readOp.getPermutationMap().isMinorIdentity())
951 auto targetType = readOp.getVectorType();
952 if (targetType.getRank() <= 1)
960 size_t dimsToDrop = 0;
961 for (
size_t i = 1; i < srcStrides.size(); ++i) {
962 int dim = srcType.getRank() - i - 1;
963 if (srcStrides[dim] == 1) {
972 auto resultTargetVecType =
973 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
974 targetType.getElementType());
976 MemRefType resultMemrefType;
977 MemRefLayoutAttrInterface layout = srcType.getLayout();
978 if (layout.isa<AffineMapAttr>() && layout.isIdentity()) {
979 resultMemrefType = MemRefType::get(
980 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
981 nullptr, srcType.getMemorySpace());
983 MemRefLayoutAttrInterface updatedLayout;
984 if (
auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
986 llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
987 updatedLayout = StridedLayoutAttr::get(strided.getContext(),
988 strided.getOffset(), strides);
990 AffineMap map = srcType.getLayout().getAffineMap();
992 for (
size_t i = 0; i < dimsToDrop; ++i) {
993 int dim = srcType.getRank() - i - 1;
999 resultMemrefType = MemRefType::get(
1000 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
1001 updatedLayout, srcType.getMemorySpace());
1004 auto loc = readOp.getLoc();
1008 ArrayAttr inBoundsAttr =
1009 readOp.getInBounds()
1011 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
1013 Value rankedReducedView = rewriter.
create<memref::SubViewOp>(
1014 loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
1017 rankedReducedView.
getType().
cast<ShapedType>(), resultTargetVecType);
1018 Value result = rewriter.
create<vector::TransferReadOp>(
1019 loc, resultTargetVecType, rankedReducedView,
1020 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1021 readOp.getPadding(),
1023 Value(), inBoundsAttr);
1033 struct CanonicalizeContractMatmulToMMT final
1037 using FilterConstraintType =
1041 FilterConstraintType constraint)
1043 filter(std::move(constraint)) {}
1048 if (!op.getMasks().empty())
1055 Value lhs = op.getLhs();
1056 Value rhs = op.getRhs();
1057 Value res = op.getAcc();
1066 static constexpr std::array<int64_t, 2> perm = {1, 0};
1067 auto iteratorTypes = op.getIteratorTypes().getValue();
1069 if (iteratorTypes.size() != 3 ||
1076 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1077 if (maps == canonicalForm)
1082 auto createTranspose = [&rewriter, loc](
Value mat) ->
Value {
1083 if (
auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1085 rewriter.
create<vector::TransposeOp>(loc, sext.getIn(), perm);
1086 return rewriter.
create<arith::ExtSIOp>(loc, mat.getType(), trans);
1088 if (
auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1090 rewriter.
create<vector::TransposeOp>(loc,
zext.getIn(), perm);
1091 return rewriter.
create<arith::ExtUIOp>(loc, mat.getType(), trans);
1093 return rewriter.
create<vector::TransposeOp>(loc, mat, perm);
1096 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1097 rhs = createTranspose(rhs);
1098 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1099 lhs = createTranspose(lhs);
1100 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1101 rhs = createTranspose(rhs);
1102 lhs = createTranspose(lhs);
1103 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1104 std::swap(rhs, lhs);
1105 rhs = createTranspose(rhs);
1106 lhs = createTranspose(lhs);
1107 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1108 std::swap(rhs, lhs);
1109 rhs = createTranspose(rhs);
1110 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1111 std::swap(lhs, rhs);
1112 lhs = createTranspose(lhs);
1113 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1114 std::swap(lhs, rhs);
1120 op.getIteratorTypes());
1125 FilterConstraintType filter;
1133 patterns.
add<VectorCreateMaskOpConversion,
1134 MaterializeTransferMask<vector::TransferReadOp>,
1135 MaterializeTransferMask<vector::TransferWriteOp>>(
1136 patterns.
getContext(), force32BitVectorIndices, benefit);
1141 patterns.
add<ShapeCastOpFolder>(patterns.
getContext(), benefit);
1146 patterns.
add<BubbleDownVectorBitCastForExtract,
1147 BubbleDownBitCastForStridedSliceExtract,
1148 BubbleUpBitCastForStridedSliceInsert>(patterns.
getContext(),
1154 std::function<
LogicalResult(vector::ContractionOp)> constraint,
1156 patterns.
add<CanonicalizeContractMatmulToMMT>(patterns.
getContext(), benefit,
1157 std::move(constraint));
1162 patterns.
add<MultiReduceToContract, CombineContractBroadcast,
1163 CombineContractABTranspose, CombineContractResultTranspose,
1164 ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
1171 patterns.
add<DropInnerMostUnitDims>(patterns.
getContext(), benefit);
1178 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
static uint64_t zext(uint32_t arg)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to a vector containing numElements elements.
static uint64_t getFirstIntValue(ArrayAttr attr)
Gets the first integer value from attr, assuming it is an integer array attribute.
Base type for affine expressions.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
AffineMap getMultiDimIdentityMap(unsigned rank)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that bubble up/down bitcast ops.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit=1)
These patterns materialize masks for various vector ops such as transfers.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to reduce the rank of the operands of vector transfer ops to operate on the...
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...