24 #include "llvm/ADT/TypeSwitch.h"
25 #include <type_traits>
29 #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
30 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
37 return op.getSource();
41 return op.getSource();
50 class TransferReadOfExtractSliceOpFolder final
60 class InsertSliceOfTransferWriteOpFolder final
65 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
70 template <
typename XferOp,
typename ExtractOrInsertOp>
73 ExtractOrInsertOp extractOrInsertSliceOp) {
74 if (xferOp.hasOutOfBoundsDim())
78 if (!extractOrInsertSliceOp.hasUnitStride()) {
80 xferOp,
"non-1 stride insert/extract, requires keeping track of "
81 "strides, this may result in needing to insert "
82 "vector.insert_strided_slice/extract_strided_slice ops");
87 LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
97 if (
failed(preconditionResult))
98 return preconditionResult;
101 readOp.getIndices().end());
104 rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
105 extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
106 indices, sourceIndices);
109 readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
111 readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
112 extractSliceOp.getDroppedDims())),
114 Value(), readOp.getInBoundsAttr());
119 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
120 tensor::InsertSliceOp insertSliceOp,
PatternRewriter &rewriter)
const {
122 .template getDefiningOp<vector::TransferWriteOp>();
129 if (
failed(preconditionResult))
130 return preconditionResult;
133 writeOp.getIndices().end());
136 rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
137 insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
141 insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
143 insertSliceOp.getDestType().getRank(),
144 insertSliceOp.getDroppedDims())),
145 writeOp.getInBoundsAttr());
150 template <
typename OpTy>
156 auto sourceInsertSliceOp =
157 insertSliceOp.getSource()
158 .template getDefiningOp<tensor::InsertSliceOp>();
159 if (!sourceInsertSliceOp)
163 if (!insertSliceOp.hasUnitStride()) {
165 "requires unit strides");
167 if (!sourceInsertSliceOp.hasUnitStride()) {
169 "requires unit strides");
173 llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
174 for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
177 if (insertSliceOp.getMixedSizes()[d] !=
178 sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
181 "requires matching sizes to fold, otherwise a copy is needed");
191 sourceInsertSliceOp.getMixedSizes(),
192 droppedDims, resolvedSizes);
197 if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
199 insertSliceOp->template getParentOfType<scf::InParallelOp>());
208 rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
209 insertSliceOp.getMixedStrides(), droppedDims,
210 sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
216 insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
218 insertSliceOp.getMixedStrides());
233 patterns.
add<TransferReadOfExtractSliceOpFolder,
234 InsertSliceOfTransferWriteOpFolder>(patterns.
getContext());
243 struct FoldTensorSubsetOpsPass final
244 :
public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
245 void runOnOperation()
override;
250 void FoldTensorSubsetOpsPass::runOnOperation() {
257 return std::make_unique<FoldTensorSubsetOpsPass>();
static Value getTensorOperand(vector::TransferReadOp op)
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(RewriterBase &rewriter, XferOp xferOp, ExtractOrInsertOp extractOrInsertSliceOp)
static MLIRContext * getContext(OpFoldResult val)
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void resolveSizesIntoOpWithSizes(ArrayRef< OpFoldResult > sourceSizes, ArrayRef< OpFoldResult > destSizes, const llvm::SmallBitVector &rankReducedSourceDims, SmallVectorImpl< OpFoldResult > &resolvedSizes)
Given sourceSizes, destSizes and information about which dimensions are dropped by the source: rankRe...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
std::unique_ptr< Pass > createFoldTensorSubsetOpsPass()
Creates an instance of the tensor subset folding pass.
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...