//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops with producer / consumers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <type_traits>

26namespace mlir {
27namespace tensor {
28#define GEN_PASS_DEF_FOLDTENSORSUBSETOPSPASS
29#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
30} // namespace tensor
31} // namespace mlir
32
33using namespace mlir;
34
35static Value getTensorOperand(vector::TransferReadOp op) {
36 return op.getBase();
37}
38
39static Value getTensorOperand(tensor::InsertSliceOp op) {
40 return op.getSource();
41}
42
43//===----------------------------------------------------------------------===//
44// Patterns
45//===----------------------------------------------------------------------===//
46
47namespace {
48/// Merge extract_slice operation with load/transferRead operation.
49class TransferReadOfExtractSliceOpFolder final
50 : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
51public:
52 using MaskableOpRewritePattern::MaskableOpRewritePattern;
53
54 FailureOr<mlir::Value>
55 matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
56 vector::MaskingOpInterface maskOp,
57 PatternRewriter &rewriter) const override;
58};
59
60/// Merge insert_slice operation with store/transferWriteOp operation.
61class InsertSliceOfTransferWriteOpFolder final
62 : public OpRewritePattern<tensor::InsertSliceOp> {
63public:
64 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
65
66 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
67 PatternRewriter &rewriter) const override;
68
69private:
70 static bool
71 doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
72};
73} // namespace
74
75template <typename XferOp, typename ExtractOrInsertOp>
77 RewriterBase &rewriter, XferOp xferOp,
78 ExtractOrInsertOp extractOrInsertSliceOp) {
79 if (xferOp.hasOutOfBoundsDim())
80 return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
81 if (xferOp.getMask())
82 return rewriter.notifyMatchFailure(xferOp, "masked transfer");
83 if (!extractOrInsertSliceOp.hasUnitStride()) {
84 return rewriter.notifyMatchFailure(
85 xferOp, "non-1 stride insert/extract, requires keeping track of "
86 "strides, this may result in needing to insert "
87 "vector.insert_strided_slice/extract_strided_slice ops");
88 }
89 return success();
90}
91
92FailureOr<mlir::Value>
93TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
94 vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
95 PatternRewriter &rewriter) const {
96 auto extractSliceOp =
97 getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
98 if (!extractSliceOp)
99 return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
100
101 LogicalResult preconditionResult =
103 extractSliceOp);
104 if (failed(preconditionResult))
105 return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
106
107 SmallVector<Value> indices(readOp.getIndices().begin(),
108 readOp.getIndices().end());
109 SmallVector<Value> sourceIndices;
110 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
111 rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
112 extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
113 indices, sourceIndices);
114
115 Operation *newOp = vector::TransferReadOp::create(
116 rewriter, readOp.getLoc(), readOp.getVectorType(),
117 extractSliceOp.getSource(), sourceIndices,
118 AffineMapAttr::get(expandDimsToRank(
119 readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
120 extractSliceOp.getDroppedDims())),
121 readOp.getPadding(),
122 /*mask=*/Value(), readOp.getInBoundsAttr());
123 if (maskOp)
124 newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
125 return newOp->getResults()[0];
126}
127
128LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
129 tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
130 auto writeOp = getTensorOperand(insertSliceOp)
131 .template getDefiningOp<vector::TransferWriteOp>();
132 if (!writeOp)
133 return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
134
135 LogicalResult preconditionResult =
137 insertSliceOp);
138 if (failed(preconditionResult))
139 return preconditionResult;
140
141 if (!doesTransferWriteCoverInsertSlice(writeOp))
142 return rewriter.notifyMatchFailure(
143 insertSliceOp, "transfer_write does not cover insert_slice");
144
145 SmallVector<Value> indices(writeOp.getIndices().begin(),
146 writeOp.getIndices().end());
147 SmallVector<Value> sourceIndices;
148 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
149 rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
150 insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
151 sourceIndices);
152
153 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
154 insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
155 AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
156 insertSliceOp.getDestType().getRank(),
157 insertSliceOp.getDroppedDims())),
158 writeOp.getInBoundsAttr());
159
160 return success();
161}
162
163bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
164 vector::TransferWriteOp writeOp) {
165 if (writeOp.getShapedType().hasStaticShape())
166 return llvm::equal(writeOp.getVectorType().getShape(),
167 writeOp.getShapedType().getShape());
168
169 // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
170
171 return false;
172}
173
174template <typename OpTy>
177
178 LogicalResult matchAndRewrite(OpTy insertSliceOp,
179 PatternRewriter &rewriter) const override {
180 auto sourceInsertSliceOp =
181 insertSliceOp.getSource()
182 .template getDefiningOp<tensor::InsertSliceOp>();
183 if (!sourceInsertSliceOp)
184 return failure();
185
186 int64_t srcDim = 0;
187 llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
188 for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
189 if (droppedDims[d])
190 continue;
191 if (insertSliceOp.getMixedSizes()[d] !=
192 sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
193 return rewriter.notifyMatchFailure(
194 sourceInsertSliceOp,
195 "requires matching sizes to fold, otherwise a copy is needed");
196 }
197 }
198
199 // If we are inside a ParallelCombining region, temporarily set the
200 // insertion point outside: only ops of ParallelCombiningOpInterface are
201 // allowed in there.
202 if (isa<mlir::ParallelCombiningOpInterface>(insertSliceOp.getOperation())) {
203 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
204 }
205
206 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
208 rewriter, insertSliceOp.getLoc(), insertSliceOp,
209 sourceInsertSliceOp, droppedDims, newOffsets, newSizes,
210 newStrides)))
211 return failure();
212
213 // Reset the insertion point.
214 rewriter.setInsertionPoint(insertSliceOp);
215 // Replace original op.
216 rewriter.replaceOpWithNewOp<OpTy>(
217 insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
218 newOffsets, newSizes, newStrides);
219 return success();
220 }
221};
222
224 : public OpRewritePattern<tensor::ExtractSliceOp> {
226
227 LogicalResult matchAndRewrite(tensor::ExtractSliceOp nextOp,
228 PatternRewriter &rewriter) const override {
229 auto prevOp = nextOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
230 if (!prevOp)
231 return failure();
232
233 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
235 rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
236 newOffsets, newSizes, newStrides)))
237 return failure();
238
239 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
240 nextOp, nextOp.getType(), prevOp.getSource(), newOffsets, newSizes,
241 newStrides);
242 return success();
243 }
244};
245
247 RewritePatternSet &patterns) {
248 patterns.add<TransferReadOfExtractSliceOpFolder,
249 InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
250}
251
259
264
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

269namespace {
270
271struct FoldTensorSubsetOpsPass final
272 : public tensor::impl::FoldTensorSubsetOpsPassBase<
273 FoldTensorSubsetOpsPass> {
274 void runOnOperation() override;
275};
276
277} // namespace
278
279void FoldTensorSubsetOpsPass::runOnOperation() {
280 RewritePatternSet patterns(&getContext());
282 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
283}
return success()
static Value getTensorOperand(vector::TransferReadOp op)
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(RewriterBase &rewriter, XferOp xferOp, ExtractOrInsertOp extractOrInsertSliceOp)
b getContext())
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
result_range getResults()
Definition Operation.h:441
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > producerOffsets, ArrayRef< OpFoldResult > producerSizes, ArrayRef< OpFoldResult > producerStrides, const llvm::SmallBitVector &droppedProducerDims, ArrayRef< OpFoldResult > consumerOffsets, ArrayRef< OpFoldResult > consumerSizes, ArrayRef< OpFoldResult > consumerStrides, SmallVector< OpFoldResult > &combinedOffsets, SmallVector< OpFoldResult > &combinedSizes, SmallVector< OpFoldResult > &combinedStrides)
Fills the combinedOffsets, combinedSizes and combinedStrides to use when combining a producer slice i...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tensor::ExtractSliceOp nextOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.