#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <type_traits>
namespace mlir {
namespace tensor {
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPSPASS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir

using namespace mlir;
/// Return the tensor operand that the folding patterns below look through.
static Value getTensorOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getTensorOperand(tensor::InsertSliceOp op) {
  return op.getSource();
}
namespace {
/// Folds a tensor.extract_slice feeding a vector.transfer_read into a
/// transfer_read on the original source tensor.
class TransferReadOfExtractSliceOpFolder final
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
                            vector::MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};
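// A sketch of the rewrite performed by this pattern (illustrative IR, value
// names hypothetical):
//
//   %0 = tensor.extract_slice %t[2, 3] [8, 8] [1, 1]
//          : tensor<16x16xf32> to tensor<8x8xf32>
//   %1 = vector.transfer_read %0[%i, %j], %pad : tensor<8x8xf32>, vector<4xf32>
//
// becomes a read directly from %t, with indices shifted by the slice offsets
// (%i + 2, %j + 3, materialized as affine.apply ops):
//
//   %1 = vector.transfer_read %t[%i2, %j3], %pad
//          : tensor<16x16xf32>, vector<4xf32>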
/// Folds a vector.transfer_write feeding a tensor.insert_slice into a
/// transfer_write on the insertion destination, provided the write covers the
/// entire inserted slice.
class InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;

private:
  static bool
  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
};
} // namespace
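// A sketch of this second rewrite (illustrative IR, value names hypothetical):
//
//   %w = vector.transfer_write %v, %src[%c0, %c0]
//          : vector<4x5xf32>, tensor<4x5xf32>
//   %r = tensor.insert_slice %w into %dest[%a, %b] [4, 5] [1, 1]
//          : tensor<4x5xf32> into tensor<16x16xf32>
//
// becomes a single write into the destination:
//
//   %r = vector.transfer_write %v, %dest[%a, %b]
//          : vector<4x5xf32>, tensor<16x16xf32>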
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
  if (!extractOrInsertSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride insert/extract, requires keeping track of "
                "strides, this may result in needing to insert "
                "vector.insert_strided_slice/extract_strided_slice ops");
  }
  return success();
}
FailureOr<mlir::Value>
TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
    vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  auto extractSliceOp =
      getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
  if (!extractSliceOp)
    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
                                                     extractSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(readOp.getIndices().begin(),
                             readOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
      extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
      indices, sourceIndices);

  Operation *newOp = vector::TransferReadOp::create(
      rewriter, readOp.getLoc(), readOp.getVectorType(),
      extractSliceOp.getSource(), sourceIndices,
      AffineMapAttr::get(expandDimsToRank(
          readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
          extractSliceOp.getDroppedDims())),
      readOp.getPadding(), /*mask=*/Value(), readOp.getInBoundsAttr());
  if (maskOp)
    newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
  return newOp->getResults()[0];
}
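// If the original transfer_read was wrapped in a vector.mask region, the
// rewritten read is re-wrapped via vector::maskOperation, e.g. (illustrative):
//   %1 = vector.mask %m { vector.transfer_read %t[...] ... } : ...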
LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
  auto writeOp = getTensorOperand(insertSliceOp)
                     .template getDefiningOp<vector::TransferWriteOp>();
  if (!writeOp)
    return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
                                                     insertSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  if (!doesTransferWriteCoverInsertSlice(writeOp))
    return rewriter.notifyMatchFailure(
        insertSliceOp, "transfer_write does not cover insert_slice");

  SmallVector<Value> indices(writeOp.getIndices().begin(),
                             writeOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
      insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
      sourceIndices);

  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
      AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
                                          insertSliceOp.getDestType().getRank(),
                                          insertSliceOp.getDroppedDims())),
      writeOp.getInBoundsAttr());

  return success();
}
bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
    vector::TransferWriteOp writeOp) {
  if (writeOp.getShapedType().hasStaticShape())
    return llvm::equal(writeOp.getVectorType().getShape(),
                       writeOp.getShapedType().getShape());

  // TODO: Use ValueBoundsConstraintSet to handle dynamic shapes.
  return false;
}
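// Coverage holds, for instance, when a vector<4x5xf32> value is written into a
// statically shaped tensor<4x5xf32>: every element of the inserted slice is
// overwritten, so the slice's prior contents cannot leak into the result.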
/// Folds a chain of two tensor.insert_slice (or tensor.parallel_insert_slice)
/// ops into a single insertion when the inserted sizes match.
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto sourceInsertSliceOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSliceOp)
      return failure();

    // TODO: relax the unit stride requirement where possible.
    if (!insertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "requires unit strides");
    }
    if (!sourceInsertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
                                         "requires unit strides");
    }

    // The outer and inner inserted slices must have matching sizes on all
    // non-rank-reduced dims; otherwise part of the intermediate survives.
    int64_t srcDim = 0;
    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
      if (droppedDims[d])
        continue;
      if (insertSliceOp.getMixedSizes()[d] !=
          sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
        return rewriter.notifyMatchFailure(
            sourceInsertSliceOp,
            "requires matching sizes to fold, otherwise a copy is needed");
      }
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
                                        sourceInsertSliceOp.getMixedSizes(),
                                        droppedDims, resolvedSizes);

    // Only parallel_insert_slice ops may appear inside an InParallel region;
    // temporarily move the insertion point outside of it.
    if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
      rewriter.setInsertionPoint(
          insertSliceOp->template getParentOfType<scf::InParallelOp>());
    }

    // Resolve offsets according to the source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedStrides(), droppedDims,
        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);

    // Reset the insertion point and replace the original op.
    rewriter.setInsertionPoint(insertSliceOp);
    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        insertSliceOp.getMixedStrides());
    return success();
  }
};
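// Sketch of the fold (illustrative IR): when the inner insert fully covers the
// intermediate tensor, the chain
//   %0 = tensor.insert_slice %src into %mid[0, 0] [4, 4] [1, 1]
//          : tensor<4x4xf32> into tensor<4x4xf32>
//   %1 = tensor.insert_slice %0 into %dest[%a, %b] [4, 4] [1, 1]
//          : tensor<4x4xf32> into tensor<16x16xf32>
// collapses to one insertion of the original source at the combined offsets:
//   %1 = tensor.insert_slice %src into %dest[%a, %b] [4, 4] [1, 1]
//          : tensor<4x4xf32> into tensor<16x16xf32>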
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
  populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
  patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
      patterns.getContext());
}

void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadOfExtractSliceOpFolder,
               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
}
namespace {
struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsPassBase<
          FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};
} // namespace

void FoldTensorSubsetOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  tensor::populateFoldTensorSubsetOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
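// Programmatic use outside the pass follows the same shape (a sketch; `module`
// and `ctx` are assumed to exist in the caller):
//   RewritePatternSet patterns(ctx);
//   tensor::populateFoldTensorSubsetOpPatterns(patterns);
//   (void)applyPatternsGreedily(module, std::move(patterns));
// With the standard pass registration, the same folding runs under
// `mlir-opt --fold-tensor-subset-ops` (flag name assumed from the pass
// definition above).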