MLIR 22.0.0git
FoldTensorSubsetOps.cpp
Go to the documentation of this file.
1//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Fold tensor subset ops with producer / consumers.
10//
11//===----------------------------------------------------------------------===//
12
24#include <type_traits>
26namespace mlir {
27namespace tensor {
28#define GEN_PASS_DEF_FOLDTENSORSUBSETOPSPASS
29#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
30} // namespace tensor
31} // namespace mlir
33using namespace mlir;
35static Value getTensorOperand(vector::TransferReadOp op) {
36 return op.getBase();
38
39static Value getTensorOperand(tensor::InsertSliceOp op) {
40 return op.getSource();
41}
43//===----------------------------------------------------------------------===//
44// Patterns
45//===----------------------------------------------------------------------===//
46
47namespace {
48/// Merge extract_slice operation with load/transferRead operation.
49class TransferReadOfExtractSliceOpFolder final
50 : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
51public:
52 using MaskableOpRewritePattern::MaskableOpRewritePattern;
53
54 FailureOr<mlir::Value>
55 matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
56 vector::MaskingOpInterface maskOp,
57 PatternRewriter &rewriter) const override;
58};
59
60/// Merge insert_slice operation with store/transferWriteOp operation.
61class InsertSliceOfTransferWriteOpFolder final
62 : public OpRewritePattern<tensor::InsertSliceOp> {
63public:
64 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
65
66 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
67 PatternRewriter &rewriter) const override;
68
69private:
70 static bool
71 doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
72};
73} // namespace
74
75template <typename XferOp, typename ExtractOrInsertOp>
77 RewriterBase &rewriter, XferOp xferOp,
78 ExtractOrInsertOp extractOrInsertSliceOp) {
79 if (xferOp.hasOutOfBoundsDim())
80 return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
81 if (xferOp.getMask())
82 return rewriter.notifyMatchFailure(xferOp, "masked transfer");
83 if (!extractOrInsertSliceOp.hasUnitStride()) {
84 return rewriter.notifyMatchFailure(
85 xferOp, "non-1 stride insert/extract, requires keeping track of "
86 "strides, this may result in needing to insert "
87 "vector.insert_strided_slice/extract_strided_slice ops");
88 }
89 return success();
90}
91
92FailureOr<mlir::Value>
93TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
94 vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
95 PatternRewriter &rewriter) const {
96 auto extractSliceOp =
97 getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
98 if (!extractSliceOp)
99 return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
100
101 LogicalResult preconditionResult =
103 extractSliceOp);
104 if (failed(preconditionResult))
105 return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
106
107 SmallVector<Value> indices(readOp.getIndices().begin(),
108 readOp.getIndices().end());
109 SmallVector<Value> sourceIndices;
111 rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
112 extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
113 indices, sourceIndices);
114
115 Operation *newOp = vector::TransferReadOp::create(
116 rewriter, readOp.getLoc(), readOp.getVectorType(),
117 extractSliceOp.getSource(), sourceIndices,
118 AffineMapAttr::get(expandDimsToRank(
119 readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
120 extractSliceOp.getDroppedDims())),
121 readOp.getPadding(),
122 /*mask=*/Value(), readOp.getInBoundsAttr());
123 if (maskOp)
124 newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
125 return newOp->getResults()[0];
126}
127
128LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
129 tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
130 auto writeOp = getTensorOperand(insertSliceOp)
131 .template getDefiningOp<vector::TransferWriteOp>();
132 if (!writeOp)
133 return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
134
135 LogicalResult preconditionResult =
137 insertSliceOp);
138 if (failed(preconditionResult))
139 return preconditionResult;
140
141 if (!doesTransferWriteCoverInsertSlice(writeOp))
142 return rewriter.notifyMatchFailure(
143 insertSliceOp, "transfer_write does not cover insert_slice");
144
145 SmallVector<Value> indices(writeOp.getIndices().begin(),
146 writeOp.getIndices().end());
147 SmallVector<Value> sourceIndices;
149 rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
150 insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
151 sourceIndices);
152
153 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
154 insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
155 AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
156 insertSliceOp.getDestType().getRank(),
157 insertSliceOp.getDroppedDims())),
158 writeOp.getInBoundsAttr());
159
160 return success();
161}
162
163bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
164 vector::TransferWriteOp writeOp) {
165 if (writeOp.getShapedType().hasStaticShape())
166 return llvm::equal(writeOp.getVectorType().getShape(),
167 writeOp.getShapedType().getShape());
168
169 // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
170
171 return false;
172}
173
// Pattern template parameterized on the insert-slice-like op type `OpTy`
// that it matches; its matchAndRewrite override follows.
// NOTE(review): the class-head lines of this pattern (listing lines 175-176)
// are missing from this listing, so the pattern's name and base class are not
// visible here. Presumably an OpRewritePattern<OpTy> (see the OpRewritePattern
// cross-reference entries) - confirm against the source.
174template <typename OpTy>
177
178 LogicalResult matchAndRewrite(OpTy insertSliceOp,
179 PatternRewriter &rewriter) const override {
180 auto sourceInsertSliceOp =
181 insertSliceOp.getSource()
182 .template getDefiningOp<tensor::InsertSliceOp>();
183 if (!sourceInsertSliceOp)
184 return failure();
185
186 // TODO: relax unit stride assumption where possible.
187 if (!insertSliceOp.hasUnitStride()) {
188 return rewriter.notifyMatchFailure(insertSliceOp,
189 "requires unit strides");
190 }
191 if (!sourceInsertSliceOp.hasUnitStride()) {
192 return rewriter.notifyMatchFailure(sourceInsertSliceOp,
193 "requires unit strides");
194 }
195
196 int64_t srcDim = 0;
197 llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
198 for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
199 if (droppedDims[d])
200 continue;
201 if (insertSliceOp.getMixedSizes()[d] !=
202 sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
203 return rewriter.notifyMatchFailure(
204 sourceInsertSliceOp,
205 "requires matching sizes to fold, otherwise a copy is needed");
206 }
207 }
208
209 // Resolve sizes according to dropped dims.
210 SmallVector<OpFoldResult> resolvedSizes;
211 // Note: the "insertSlice" case is symmetrical to the extract/subview case:
212 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
213 // passed as the destination to the helper function.
214 affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
215 sourceInsertSliceOp.getMixedSizes(),
216 droppedDims, resolvedSizes);
217
218 // If we are inside a ParallelCombining region, temporarily set the
219 // insertion point outside: only ops of ParallelCombiningOpInterface are
220 // allowed in there.
221 if (isa<mlir::ParallelCombiningOpInterface>(insertSliceOp.getOperation())) {
222 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
223 }
224
225 // Resolve offsets according to source offsets and strides.
226 SmallVector<Value> resolvedOffsets;
227 // Note: the "insertSlice" case is symmetrical to the extract/subview case:
228 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
229 // passed as the destination to the helper function.
231 rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
232 insertSliceOp.getMixedStrides(), droppedDims,
233 sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
234
235 // Reset the insertion point.
236 rewriter.setInsertionPoint(insertSliceOp);
237 // Replace original op.
238 rewriter.replaceOpWithNewOp<OpTy>(
239 insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
240 getAsOpFoldResult(resolvedOffsets), resolvedSizes,
241 insertSliceOp.getMixedStrides());
242
243 return success();
244 }
245};
246
253
// Adds the two vector-transfer folding patterns declared in this file
// (TransferReadOfExtractSliceOpFolder, InsertSliceOfTransferWriteOpFolder) to
// `patterns`.
// NOTE(review): the header of this function (listing lines 254-255) and the
// whole preceding definition (listing lines 247-252) are missing from this
// listing. Per the cross-reference entries, this body presumably belongs to
// populateFoldTensorSubsetIntoVectorTransferPatterns - confirm.
256 patterns.add<TransferReadOfExtractSliceOpFolder,
257 InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
258}
259
260//===----------------------------------------------------------------------===//
261// Pass registration
262//===----------------------------------------------------------------------===//
263
264namespace {
265
// File-local pass struct; its runOnOperation override is defined out of line.
// NOTE(review): the base-class line (listing line 267) is missing from this
// listing. Given GEN_PASS_DEF_FOLDTENSORSUBSETOPSPASS at the top of the file,
// it is presumably the tablegen-generated pass base, CRTP-parameterized on
// FoldTensorSubsetOpsPass - confirm.
266struct FoldTensorSubsetOpsPass final
268 FoldTensorSubsetOpsPass> {
269 void runOnOperation() override;
270};
271
272} // namespace
273
// Builds a RewritePatternSet and applies it greedily to the pass's root
// operation. The LogicalResult returned by applyPatternsGreedily is explicitly
// discarded, so the pass does not fail when the rewrite does not converge.
// NOTE(review): listing line 276 (presumably the call that fills `patterns`)
// is missing from this listing. As shown, the pattern set would stay empty.
// Verify against the source.
274void FoldTensorSubsetOpsPass::runOnOperation() {
275 RewritePatternSet patterns(&getContext());
277 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
278}
return success()
static Value getTensorOperand(vector::TransferReadOp op)
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(RewriterBase &rewriter, XferOp xferOp, ExtractOrInsertOp extractOrInsertSliceOp)
b getContext())
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
result_range getResults()
Definition Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void resolveSizesIntoOpWithSizes(ArrayRef< OpFoldResult > sourceSizes, ArrayRef< OpFoldResult > destSizes, const llvm::SmallBitVector &rankReducedSourceDims, SmallVectorImpl< OpFoldResult > &resolvedSizes)
Given sourceSizes, destSizes and information about which dimensions are dropped by the source: rankRe...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends to `patterns` the patterns that fold tensor subset ops into consumer load/store ops.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.