26 #include "llvm/ADT/TypeSwitch.h"
27 #include <type_traits>
31 #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
32 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
39 return op.getSource();
43 return op.getSource();
52 class TransferReadOfExtractSliceOpFolder final
55 using MaskableOpRewritePattern::MaskableOpRewritePattern;
57 FailureOr<mlir::Value>
58 matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
59 vector::MaskingOpInterface maskOp,
64 class InsertSliceOfTransferWriteOpFolder final
69 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
74 doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
78 template <
typename XferOp,
typename ExtractOrInsertOp>
81 ExtractOrInsertOp extractOrInsertSliceOp) {
82 if (xferOp.hasOutOfBoundsDim())
86 if (!extractOrInsertSliceOp.hasUnitStride()) {
88 xferOp,
"non-1 stride insert/extract, requires keeping track of "
89 "strides, this may result in needing to insert "
90 "vector.insert_strided_slice/extract_strided_slice ops");
95 FailureOr<mlir::Value>
96 TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
97 vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
104 LogicalResult preconditionResult =
107 if (failed(preconditionResult))
111 readOp.getIndices().end());
114 rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
115 extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
116 indices, sourceIndices);
119 readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
122 readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
123 extractSliceOp.getDroppedDims())),
125 Value(), readOp.getInBoundsAttr());
131 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
132 tensor::InsertSliceOp insertSliceOp,
PatternRewriter &rewriter)
const {
134 .template getDefiningOp<vector::TransferWriteOp>();
138 LogicalResult preconditionResult =
141 if (failed(preconditionResult))
142 return preconditionResult;
144 if (!doesTransferWriteCoverInsertSlice(writeOp))
146 insertSliceOp,
"transfer_write does not cover insert_slice");
149 writeOp.getIndices().end());
152 rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
153 insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
157 insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
159 insertSliceOp.getDestType().getRank(),
160 insertSliceOp.getDroppedDims())),
161 writeOp.getInBoundsAttr());
166 bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
167 vector::TransferWriteOp writeOp) {
168 if (writeOp.getShapedType().hasStaticShape())
169 return llvm::equal(writeOp.getVectorType().getShape(),
170 writeOp.getShapedType().getShape());
177 template <
typename OpTy>
183 auto sourceInsertSliceOp =
184 insertSliceOp.getSource()
185 .template getDefiningOp<tensor::InsertSliceOp>();
186 if (!sourceInsertSliceOp)
190 if (!insertSliceOp.hasUnitStride()) {
192 "requires unit strides");
194 if (!sourceInsertSliceOp.hasUnitStride()) {
196 "requires unit strides");
200 llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
201 for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
204 if (insertSliceOp.getMixedSizes()[d] !=
205 sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
208 "requires matching sizes to fold, otherwise a copy is needed");
218 sourceInsertSliceOp.getMixedSizes(),
219 droppedDims, resolvedSizes);
224 if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
226 insertSliceOp->template getParentOfType<scf::InParallelOp>());
235 rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
236 insertSliceOp.getMixedStrides(), droppedDims,
237 sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
243 insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
245 insertSliceOp.getMixedStrides());
260 patterns.
add<TransferReadOfExtractSliceOpFolder,
261 InsertSliceOfTransferWriteOpFolder>(patterns.
getContext());
270 struct FoldTensorSubsetOpsPass final
271 :
public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
272 void runOnOperation()
override;
277 void FoldTensorSubsetOpsPass::runOnOperation() {
284 return std::make_unique<FoldTensorSubsetOpsPass>();
static Value getTensorOperand(vector::TransferReadOp op)
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(RewriterBase &rewriter, XferOp xferOp, ExtractOrInsertOp extractOrInsertSliceOp)
static MLIRContext * getContext(OpFoldResult val)
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void resolveSizesIntoOpWithSizes(ArrayRef< OpFoldResult > sourceSizes, ArrayRef< OpFoldResult > destSizes, const llvm::SmallBitVector &rankReducedSourceDims, SmallVectorImpl< OpFoldResult > &resolvedSizes)
Given sourceSizes, destSizes and information about which dimensions are dropped by the source: rankRe...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
std::unique_ptr< Pass > createFoldTensorSubsetOpsPass()
Creates an instance of the tensor subset folding pass.
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.