MLIR  20.0.0git
FoldTensorSubsetOps.cpp
Go to the documentation of this file.
1 //===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold tensor subset ops with producer / consumers.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>
28 
29 namespace mlir {
30 namespace tensor {
31 #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
32 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
33 } // namespace tensor
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 static Value getTensorOperand(vector::TransferReadOp op) {
39  return op.getSource();
40 }
41 
42 static Value getTensorOperand(tensor::InsertSliceOp op) {
43  return op.getSource();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // Patterns
48 //===----------------------------------------------------------------------===//
49 
50 namespace {
51 /// Merge extract_slice operation with load/transferRead operation.
52 class TransferReadOfExtractSliceOpFolder final
53  : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
54 public:
55  using MaskableOpRewritePattern::MaskableOpRewritePattern;
56 
57  FailureOr<mlir::Value>
58  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
59  vector::MaskingOpInterface maskOp,
60  PatternRewriter &rewriter) const override;
61 };
62 
63 /// Merge insert_slice operation with store/transferWriteOp operation.
64 class InsertSliceOfTransferWriteOpFolder final
65  : public OpRewritePattern<tensor::InsertSliceOp> {
66 public:
68 
69  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
70  PatternRewriter &rewriter) const override;
71 
72 private:
73  static bool
74  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
75 };
76 } // namespace
77 
78 template <typename XferOp, typename ExtractOrInsertOp>
80  RewriterBase &rewriter, XferOp xferOp,
81  ExtractOrInsertOp extractOrInsertSliceOp) {
82  if (xferOp.hasOutOfBoundsDim())
83  return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
84  if (xferOp.getMask())
85  return rewriter.notifyMatchFailure(xferOp, "masked transfer");
86  if (!extractOrInsertSliceOp.hasUnitStride()) {
87  return rewriter.notifyMatchFailure(
88  xferOp, "non-1 stride insert/extract, requires keeping track of "
89  "strides, this may result in needing to insert "
90  "vector.insert_strided_slice/extract_strided_slice ops");
91  }
92  return success();
93 }
94 
95 FailureOr<mlir::Value>
96 TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
97  vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
98  PatternRewriter &rewriter) const {
99  auto extractSliceOp =
100  getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
101  if (!extractSliceOp)
102  return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
103 
104  LogicalResult preconditionResult =
106  extractSliceOp);
107  if (failed(preconditionResult))
108  return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
109 
110  SmallVector<Value> indices(readOp.getIndices().begin(),
111  readOp.getIndices().end());
112  SmallVector<Value> sourceIndices;
114  rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
115  extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
116  indices, sourceIndices);
117 
118  Operation *newOp = rewriter.create<vector::TransferReadOp>(
119  readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
120  sourceIndices,
122  readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
123  extractSliceOp.getDroppedDims())),
124  readOp.getPadding(),
125  /*mask=*/Value(), readOp.getInBoundsAttr());
126  if (maskOp)
127  newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
128  return newOp->getResults()[0];
129 }
130 
131 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
132  tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
133  auto writeOp = getTensorOperand(insertSliceOp)
134  .template getDefiningOp<vector::TransferWriteOp>();
135  if (!writeOp)
136  return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
137 
138  LogicalResult preconditionResult =
140  insertSliceOp);
141  if (failed(preconditionResult))
142  return preconditionResult;
143 
144  if (!doesTransferWriteCoverInsertSlice(writeOp))
145  return rewriter.notifyMatchFailure(
146  insertSliceOp, "transfer_write does not cover insert_slice");
147 
148  SmallVector<Value> indices(writeOp.getIndices().begin(),
149  writeOp.getIndices().end());
150  SmallVector<Value> sourceIndices;
152  rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
153  insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
154  sourceIndices);
155 
156  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
157  insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
158  AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
159  insertSliceOp.getDestType().getRank(),
160  insertSliceOp.getDroppedDims())),
161  writeOp.getInBoundsAttr());
162 
163  return success();
164 }
165 
166 bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
167  vector::TransferWriteOp writeOp) {
168  if (writeOp.getShapedType().hasStaticShape())
169  return llvm::equal(writeOp.getVectorType().getShape(),
170  writeOp.getShapedType().getShape());
171 
172  // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
173 
174  return false;
175 }
176 
177 template <typename OpTy>
180 
181  LogicalResult matchAndRewrite(OpTy insertSliceOp,
182  PatternRewriter &rewriter) const override {
183  auto sourceInsertSliceOp =
184  insertSliceOp.getSource()
185  .template getDefiningOp<tensor::InsertSliceOp>();
186  if (!sourceInsertSliceOp)
187  return failure();
188 
189  // TODO: relax unit stride assumption where possible.
190  if (!insertSliceOp.hasUnitStride()) {
191  return rewriter.notifyMatchFailure(insertSliceOp,
192  "requires unit strides");
193  }
194  if (!sourceInsertSliceOp.hasUnitStride()) {
195  return rewriter.notifyMatchFailure(sourceInsertSliceOp,
196  "requires unit strides");
197  }
198 
199  int64_t srcDim = 0;
200  llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
201  for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
202  if (droppedDims[d])
203  continue;
204  if (insertSliceOp.getMixedSizes()[d] !=
205  sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
206  return rewriter.notifyMatchFailure(
207  sourceInsertSliceOp,
208  "requires matching sizes to fold, otherwise a copy is needed");
209  }
210  }
211 
212  // Resolve sizes according to dropped dims.
213  SmallVector<OpFoldResult> resolvedSizes;
214  // Note: the "insertSlice" case is symmetrical to the extract/subview case:
215  // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
216  // passed as the destination to the helper function.
217  affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
218  sourceInsertSliceOp.getMixedSizes(),
219  droppedDims, resolvedSizes);
220 
221  // If we are inside an InParallel region, temporarily set the insertion
222  // point outside: only tensor.parallel_insert_slice ops are allowed in
223  // there.
224  if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
225  rewriter.setInsertionPoint(
226  insertSliceOp->template getParentOfType<scf::InParallelOp>());
227  }
228 
229  // Resolve offsets according to source offsets and strides.
230  SmallVector<Value> resolvedOffsets;
231  // Note: the "insertSlice" case is symmetrical to the extract/subview case:
232  // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
233  // passed as the destination to the helper function.
235  rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
236  insertSliceOp.getMixedStrides(), droppedDims,
237  sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
238 
239  // Reset the insertion point.
240  rewriter.setInsertionPoint(insertSliceOp);
241  // Replace original op.
242  rewriter.replaceOpWithNewOp<OpTy>(
243  insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
244  getAsOpFoldResult(resolvedOffsets), resolvedSizes,
245  insertSliceOp.getMixedStrides());
246 
247  return success();
248  }
249 };
250 
255  patterns.getContext());
256 }
257 
259  RewritePatternSet &patterns) {
260  patterns.add<TransferReadOfExtractSliceOpFolder,
261  InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // Pass registration
266 //===----------------------------------------------------------------------===//
267 
268 namespace {
269 
270 struct FoldTensorSubsetOpsPass final
271  : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
272  void runOnOperation() override;
273 };
274 
275 } // namespace
276 
277 void FoldTensorSubsetOpsPass::runOnOperation() {
278  RewritePatternSet patterns(&getContext());
280  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
281 }
282 
283 std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
284  return std::make_unique<FoldTensorSubsetOpsPass>();
285 }
static Value getTensorOperand(vector::TransferReadOp op)
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(RewriterBase &rewriter, XferOp xferOp, ExtractOrInsertOp extractOrInsertSliceOp)
static MLIRContext * getContext(OpFoldResult val)
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void resolveSizesIntoOpWithSizes(ArrayRef< OpFoldResult > sourceSizes, ArrayRef< OpFoldResult > destSizes, const llvm::SmallBitVector &rankReducedSourceDims, SmallVectorImpl< OpFoldResult > &resolvedSizes)
Given sourceSizes, destSizes and information about which dimensions are dropped by the source: rankRe...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
std::unique_ptr< Pass > createFoldTensorSubsetOpsPass()
Creates an instance of the tensor subset folding pass.
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
Definition: AffineMap.cpp:953
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:157