MLIR  19.0.0git
FoldTensorSubsetOps.cpp
Go to the documentation of this file.
//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops into their producers / consumers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>

27 namespace mlir {
28 namespace tensor {
29 #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
30 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
31 } // namespace tensor
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 static Value getTensorOperand(vector::TransferReadOp op) {
37  return op.getSource();
38 }
39 
40 static Value getTensorOperand(tensor::InsertSliceOp op) {
41  return op.getSource();
42 }
43 
44 //===----------------------------------------------------------------------===//
45 // Patterns
46 //===----------------------------------------------------------------------===//
47 
48 namespace {
49 /// Merge extract_slice operation with load/transferRead operation.
50 class TransferReadOfExtractSliceOpFolder final
51  : public OpRewritePattern<vector::TransferReadOp> {
52 public:
54 
55  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
56  PatternRewriter &rewriter) const override;
57 };
58 
59 /// Merge insert_slice operation with store/transferWriteOp operation.
60 class InsertSliceOfTransferWriteOpFolder final
61  : public OpRewritePattern<tensor::InsertSliceOp> {
62 public:
64 
65  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
66  PatternRewriter &rewriter) const override;
67 };
68 } // namespace
69 
70 template <typename XferOp, typename ExtractOrInsertOp>
72  RewriterBase &rewriter, XferOp xferOp,
73  ExtractOrInsertOp extractOrInsertSliceOp) {
74  if (xferOp.hasOutOfBoundsDim())
75  return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
76  if (xferOp.getMask())
77  return rewriter.notifyMatchFailure(xferOp, "masked transfer");
78  if (!extractOrInsertSliceOp.hasUnitStride()) {
79  return rewriter.notifyMatchFailure(
80  xferOp, "non-1 stride insert/extract, requires keeping track of "
81  "strides, this may result in needing to insert "
82  "vector.insert_strided_slice/extract_strided_slice ops");
83  }
84  return success();
85 }
86 
87 LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
88  vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
89  auto extractSliceOp =
90  getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
91  if (!extractSliceOp)
92  return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
93 
94  LogicalResult preconditionResult =
96  extractSliceOp);
97  if (failed(preconditionResult))
98  return preconditionResult;
99 
100  SmallVector<Value> indices(readOp.getIndices().begin(),
101  readOp.getIndices().end());
102  SmallVector<Value> sourceIndices;
104  rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
105  extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
106  indices, sourceIndices);
107 
108  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
109  readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
111  readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
112  extractSliceOp.getDroppedDims())),
113  readOp.getPadding(),
114  /*mask=*/Value(), readOp.getInBoundsAttr());
115 
116  return success();
117 }
118 
119 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
120  tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
121  auto writeOp = getTensorOperand(insertSliceOp)
122  .template getDefiningOp<vector::TransferWriteOp>();
123  if (!writeOp)
124  return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
125 
126  LogicalResult preconditionResult =
128  insertSliceOp);
129  if (failed(preconditionResult))
130  return preconditionResult;
131 
132  SmallVector<Value> indices(writeOp.getIndices().begin(),
133  writeOp.getIndices().end());
134  SmallVector<Value> sourceIndices;
136  rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
137  insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
138  sourceIndices);
139 
140  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
141  insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
142  AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
143  insertSliceOp.getDestType().getRank(),
144  insertSliceOp.getDroppedDims())),
145  writeOp.getInBoundsAttr());
146 
147  return success();
148 }
149 
150 template <typename OpTy>
153 
154  LogicalResult matchAndRewrite(OpTy insertSliceOp,
155  PatternRewriter &rewriter) const override {
156  auto sourceInsertSliceOp =
157  insertSliceOp.getSource()
158  .template getDefiningOp<tensor::InsertSliceOp>();
159  if (!sourceInsertSliceOp)
160  return failure();
161 
162  // TODO: relax unit stride assumption where possible.
163  if (!insertSliceOp.hasUnitStride()) {
164  return rewriter.notifyMatchFailure(insertSliceOp,
165  "requires unit strides");
166  }
167  if (!sourceInsertSliceOp.hasUnitStride()) {
168  return rewriter.notifyMatchFailure(sourceInsertSliceOp,
169  "requires unit strides");
170  }
171 
172  int64_t srcDim = 0;
173  llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
174  for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
175  if (droppedDims[d])
176  continue;
177  if (insertSliceOp.getMixedSizes()[d] !=
178  sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
179  return rewriter.notifyMatchFailure(
180  sourceInsertSliceOp,
181  "requires matching sizes to fold, otherwise a copy is needed");
182  }
183  }
184 
185  // Resolve sizes according to dropped dims.
186  SmallVector<OpFoldResult> resolvedSizes;
187  // Note: the "insertSlice" case is symmetrical to the extract/subview case:
188  // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
189  // passed as the destination to the helper function.
190  affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
191  sourceInsertSliceOp.getMixedSizes(),
192  droppedDims, resolvedSizes);
193 
194  // If we are inside an InParallel region, temporarily set the insertion
195  // point outside: only tensor.parallel_insert_slice ops are allowed in
196  // there.
197  if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
198  rewriter.setInsertionPoint(
199  insertSliceOp->template getParentOfType<scf::InParallelOp>());
200  }
201 
202  // Resolve offsets according to source offsets and strides.
203  SmallVector<Value> resolvedOffsets;
204  // Note: the "insertSlice" case is symmetrical to the extract/subview case:
205  // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
206  // passed as the destination to the helper function.
208  rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
209  insertSliceOp.getMixedStrides(), droppedDims,
210  sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
211 
212  // Reset the insertion point.
213  rewriter.setInsertionPoint(insertSliceOp);
214  // Replace original op.
215  rewriter.replaceOpWithNewOp<OpTy>(
216  insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
217  getAsOpFoldResult(resolvedOffsets), resolvedSizes,
218  insertSliceOp.getMixedStrides());
219 
220  return success();
221  }
222 };
223 
228  patterns.getContext());
229 }
230 
232  RewritePatternSet &patterns) {
233  patterns.add<TransferReadOfExtractSliceOpFolder,
234  InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // Pass registration
239 //===----------------------------------------------------------------------===//
240 
241 namespace {
242 
243 struct FoldTensorSubsetOpsPass final
244  : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
245  void runOnOperation() override;
246 };
247 
248 } // namespace
249 
250 void FoldTensorSubsetOpsPass::runOnOperation() {
251  RewritePatternSet patterns(&getContext());
253  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
254 }
255 
256 std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
257  return std::make_unique<FoldTensorSubsetOpsPass>();
258 }
static Value getTensorOperand(vector::TransferReadOp op)
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(RewriterBase &rewriter, XferOp xferOp, ExtractOrInsertOp extractOrInsertSliceOp)
static MLIRContext * getContext(OpFoldResult val)
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
MLIRContext * getContext() const
Definition: PatternMatch.h:812
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:836
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:708
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:534
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void resolveSizesIntoOpWithSizes(ArrayRef< OpFoldResult > sourceSizes, ArrayRef< OpFoldResult > destSizes, const llvm::SmallBitVector &rankReducedSourceDims, SmallVectorImpl< OpFoldResult > &resolvedSizes)
Given sourceSizes, destSizes and information about which dimensions are dropped by the source: rankRe...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
std::unique_ptr< Pass > createFoldTensorSubsetOpsPass()
Creates an instance of the tensor subset folding pass.
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
Definition: AffineMap.cpp:917
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357