MLIR  22.0.0git
FoldTensorSubsetOps.cpp
Go to the documentation of this file.
1 //===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold tensor subset ops with producer / consumers.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/AffineMap.h"
24 #include <type_traits>
25 
26 namespace mlir {
27 namespace tensor {
28 #define GEN_PASS_DEF_FOLDTENSORSUBSETOPSPASS
29 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
30 } // namespace tensor
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 static Value getTensorOperand(vector::TransferReadOp op) {
36  return op.getBase();
37 }
38 
39 static Value getTensorOperand(tensor::InsertSliceOp op) {
40  return op.getSource();
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // Patterns
45 //===----------------------------------------------------------------------===//
46 
47 namespace {
48 /// Merge extract_slice operation with load/transferRead operation.
49 class TransferReadOfExtractSliceOpFolder final
50  : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
51 public:
52  using MaskableOpRewritePattern::MaskableOpRewritePattern;
53 
54  FailureOr<mlir::Value>
55  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
56  vector::MaskingOpInterface maskOp,
57  PatternRewriter &rewriter) const override;
58 };
59 
60 /// Merge insert_slice operation with store/transferWriteOp operation.
61 class InsertSliceOfTransferWriteOpFolder final
62  : public OpRewritePattern<tensor::InsertSliceOp> {
63 public:
65 
66  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
67  PatternRewriter &rewriter) const override;
68 
69 private:
70  static bool
71  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
72 };
73 } // namespace
74 
75 template <typename XferOp, typename ExtractOrInsertOp>
77  RewriterBase &rewriter, XferOp xferOp,
78  ExtractOrInsertOp extractOrInsertSliceOp) {
79  if (xferOp.hasOutOfBoundsDim())
80  return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
81  if (xferOp.getMask())
82  return rewriter.notifyMatchFailure(xferOp, "masked transfer");
83  if (!extractOrInsertSliceOp.hasUnitStride()) {
84  return rewriter.notifyMatchFailure(
85  xferOp, "non-1 stride insert/extract, requires keeping track of "
86  "strides, this may result in needing to insert "
87  "vector.insert_strided_slice/extract_strided_slice ops");
88  }
89  return success();
90 }
91 
92 FailureOr<mlir::Value>
93 TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
94  vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
95  PatternRewriter &rewriter) const {
96  auto extractSliceOp =
97  getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
98  if (!extractSliceOp)
99  return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
100 
101  LogicalResult preconditionResult =
103  extractSliceOp);
104  if (failed(preconditionResult))
105  return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
106 
107  SmallVector<Value> indices(readOp.getIndices().begin(),
108  readOp.getIndices().end());
109  SmallVector<Value> sourceIndices;
111  rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
112  extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
113  indices, sourceIndices);
114 
115  Operation *newOp = vector::TransferReadOp::create(
116  rewriter, readOp.getLoc(), readOp.getVectorType(),
117  extractSliceOp.getSource(), sourceIndices,
119  readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
120  extractSliceOp.getDroppedDims())),
121  readOp.getPadding(),
122  /*mask=*/Value(), readOp.getInBoundsAttr());
123  if (maskOp)
124  newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
125  return newOp->getResults()[0];
126 }
127 
128 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
129  tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
130  auto writeOp = getTensorOperand(insertSliceOp)
131  .template getDefiningOp<vector::TransferWriteOp>();
132  if (!writeOp)
133  return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
134 
135  LogicalResult preconditionResult =
137  insertSliceOp);
138  if (failed(preconditionResult))
139  return preconditionResult;
140 
141  if (!doesTransferWriteCoverInsertSlice(writeOp))
142  return rewriter.notifyMatchFailure(
143  insertSliceOp, "transfer_write does not cover insert_slice");
144 
145  SmallVector<Value> indices(writeOp.getIndices().begin(),
146  writeOp.getIndices().end());
147  SmallVector<Value> sourceIndices;
149  rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
150  insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
151  sourceIndices);
152 
153  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
154  insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
155  AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
156  insertSliceOp.getDestType().getRank(),
157  insertSliceOp.getDroppedDims())),
158  writeOp.getInBoundsAttr());
159 
160  return success();
161 }
162 
163 bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
164  vector::TransferWriteOp writeOp) {
165  if (writeOp.getShapedType().hasStaticShape())
166  return llvm::equal(writeOp.getVectorType().getShape(),
167  writeOp.getShapedType().getShape());
168 
169  // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
170 
171  return false;
172 }
173 
174 template <typename OpTy>
177 
178  LogicalResult matchAndRewrite(OpTy insertSliceOp,
179  PatternRewriter &rewriter) const override {
180  auto sourceInsertSliceOp =
181  insertSliceOp.getSource()
182  .template getDefiningOp<tensor::InsertSliceOp>();
183  if (!sourceInsertSliceOp)
184  return failure();
185 
186  // TODO: relax unit stride assumption where possible.
187  if (!insertSliceOp.hasUnitStride()) {
188  return rewriter.notifyMatchFailure(insertSliceOp,
189  "requires unit strides");
190  }
191  if (!sourceInsertSliceOp.hasUnitStride()) {
192  return rewriter.notifyMatchFailure(sourceInsertSliceOp,
193  "requires unit strides");
194  }
195 
196  int64_t srcDim = 0;
197  llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
198  for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
199  if (droppedDims[d])
200  continue;
201  if (insertSliceOp.getMixedSizes()[d] !=
202  sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
203  return rewriter.notifyMatchFailure(
204  sourceInsertSliceOp,
205  "requires matching sizes to fold, otherwise a copy is needed");
206  }
207  }
208 
209  // Resolve sizes according to dropped dims.
210  SmallVector<OpFoldResult> resolvedSizes;
211  // Note: the "insertSlice" case is symmetrical to the extract/subview case:
212  // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
213  // passed as the destination to the helper function.
214  affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
215  sourceInsertSliceOp.getMixedSizes(),
216  droppedDims, resolvedSizes);
217 
218  // If we are inside an InParallel region, temporarily set the insertion
219  // point outside: only tensor.parallel_insert_slice ops are allowed in
220  // there.
221  if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
222  rewriter.setInsertionPoint(
223  insertSliceOp->template getParentOfType<scf::InParallelOp>());
224  }
225 
226  // Resolve offsets according to source offsets and strides.
227  SmallVector<Value> resolvedOffsets;
228  // Note: the "insertSlice" case is symmetrical to the extract/subview case:
229  // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
230  // passed as the destination to the helper function.
232  rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
233  insertSliceOp.getMixedStrides(), droppedDims,
234  sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
235 
236  // Reset the insertion point.
237  rewriter.setInsertionPoint(insertSliceOp);
238  // Replace original op.
239  rewriter.replaceOpWithNewOp<OpTy>(
240  insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
241  getAsOpFoldResult(resolvedOffsets), resolvedSizes,
242  insertSliceOp.getMixedStrides());
243 
244  return success();
245  }
246 };
247 
252  patterns.getContext());
253 }
254 
257  patterns.add<TransferReadOfExtractSliceOpFolder,
258  InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // Pass registration
263 //===----------------------------------------------------------------------===//
264 
265 namespace {
266 
267 struct FoldTensorSubsetOpsPass final
268  : public tensor::impl::FoldTensorSubsetOpsPassBase<
269  FoldTensorSubsetOpsPass> {
270  void runOnOperation() override;
271 };
272 
273 } // namespace
274 
275 void FoldTensorSubsetOpsPass::runOnOperation() {
278  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
279 }
static Value getTensorOperand(vector::TransferReadOp op)
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(RewriterBase &rewriter, XferOp xferOp, ExtractOrInsertOp extractOrInsertSliceOp)
static MLIRContext * getContext(OpFoldResult val)
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:783
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
void resolveSizesIntoOpWithSizes(ArrayRef< OpFoldResult > sourceSizes, ArrayRef< OpFoldResult > destSizes, const llvm::SmallBitVector &rankReducedSourceDims, SmallVectorImpl< OpFoldResult > &resolvedSizes)
Given sourceSizes, destSizes and information about which dimensions are dropped by the source: rankRe...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
Definition: AffineMap.cpp:948
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e. inside a `vector.mask` op region).
Definition: VectorUtils.h:163