MLIR  20.0.0git
FoldTensorSubsetOps.cpp
Go to the documentation of this file.
1 //===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold tensor subset ops with producer / consumers.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>
27 
28 namespace mlir {
29 namespace tensor {
30 #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
31 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
32 } // namespace tensor
33 } // namespace mlir
34 
35 using namespace mlir;
36 
37 static Value getTensorOperand(vector::TransferReadOp op) {
38  return op.getSource();
39 }
40 
41 static Value getTensorOperand(tensor::InsertSliceOp op) {
42  return op.getSource();
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // Patterns
47 //===----------------------------------------------------------------------===//
48 
49 namespace {
50 /// Merge extract_slice operation with load/transferRead operation.
51 class TransferReadOfExtractSliceOpFolder final
52  : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
53 public:
54  using MaskableOpRewritePattern::MaskableOpRewritePattern;
55 
56  FailureOr<mlir::Value>
57  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
58  vector::MaskingOpInterface maskOp,
59  PatternRewriter &rewriter) const override;
60 };
61 
62 /// Merge insert_slice operation with store/transferWriteOp operation.
63 class InsertSliceOfTransferWriteOpFolder final
64  : public OpRewritePattern<tensor::InsertSliceOp> {
65 public:
67 
68  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
69  PatternRewriter &rewriter) const override;
70 };
71 } // namespace
72 
73 template <typename XferOp, typename ExtractOrInsertOp>
75  RewriterBase &rewriter, XferOp xferOp,
76  ExtractOrInsertOp extractOrInsertSliceOp) {
77  if (xferOp.hasOutOfBoundsDim())
78  return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
79  if (xferOp.getMask())
80  return rewriter.notifyMatchFailure(xferOp, "masked transfer");
81  if (!extractOrInsertSliceOp.hasUnitStride()) {
82  return rewriter.notifyMatchFailure(
83  xferOp, "non-1 stride insert/extract, requires keeping track of "
84  "strides, this may result in needing to insert "
85  "vector.insert_strided_slice/extract_strided_slice ops");
86  }
87  return success();
88 }
89 
90 FailureOr<mlir::Value>
91 TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
92  vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
93  PatternRewriter &rewriter) const {
94  auto extractSliceOp =
95  getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
96  if (!extractSliceOp)
97  return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
98 
99  LogicalResult preconditionResult =
101  extractSliceOp);
102  if (failed(preconditionResult))
103  return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
104 
105  SmallVector<Value> indices(readOp.getIndices().begin(),
106  readOp.getIndices().end());
107  SmallVector<Value> sourceIndices;
109  rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
110  extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
111  indices, sourceIndices);
112 
113  Operation *newOp = rewriter.create<vector::TransferReadOp>(
114  readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
115  sourceIndices,
117  readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
118  extractSliceOp.getDroppedDims())),
119  readOp.getPadding(),
120  /*mask=*/Value(), readOp.getInBoundsAttr());
121  if (maskOp)
122  newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
123  return newOp->getResults()[0];
124 }
125 
126 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
127  tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
128  auto writeOp = getTensorOperand(insertSliceOp)
129  .template getDefiningOp<vector::TransferWriteOp>();
130  if (!writeOp)
131  return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
132 
133  LogicalResult preconditionResult =
135  insertSliceOp);
136  if (failed(preconditionResult))
137  return preconditionResult;
138 
139  SmallVector<Value> indices(writeOp.getIndices().begin(),
140  writeOp.getIndices().end());
141  SmallVector<Value> sourceIndices;
143  rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
144  insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
145  sourceIndices);
146 
147  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
148  insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
149  AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
150  insertSliceOp.getDestType().getRank(),
151  insertSliceOp.getDroppedDims())),
152  writeOp.getInBoundsAttr());
153 
154  return success();
155 }
156 
157 template <typename OpTy>
160 
161  LogicalResult matchAndRewrite(OpTy insertSliceOp,
162  PatternRewriter &rewriter) const override {
163  auto sourceInsertSliceOp =
164  insertSliceOp.getSource()
165  .template getDefiningOp<tensor::InsertSliceOp>();
166  if (!sourceInsertSliceOp)
167  return failure();
168 
169  // TODO: relax unit stride assumption where possible.
170  if (!insertSliceOp.hasUnitStride()) {
171  return rewriter.notifyMatchFailure(insertSliceOp,
172  "requires unit strides");
173  }
174  if (!sourceInsertSliceOp.hasUnitStride()) {
175  return rewriter.notifyMatchFailure(sourceInsertSliceOp,
176  "requires unit strides");
177  }
178 
179  int64_t srcDim = 0;
180  llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
181  for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
182  if (droppedDims[d])
183  continue;
184  if (insertSliceOp.getMixedSizes()[d] !=
185  sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
186  return rewriter.notifyMatchFailure(
187  sourceInsertSliceOp,
188  "requires matching sizes to fold, otherwise a copy is needed");
189  }
190  }
191 
192  // Resolve sizes according to dropped dims.
193  SmallVector<OpFoldResult> resolvedSizes;
194  // Note: the "insertSlice" case is symmetrical to the extract/subview case:
195  // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
196  // passed as the destination to the helper function.
197  affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
198  sourceInsertSliceOp.getMixedSizes(),
199  droppedDims, resolvedSizes);
200 
201  // If we are inside an InParallel region, temporarily set the insertion
202  // point outside: only tensor.parallel_insert_slice ops are allowed in
203  // there.
204  if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
205  rewriter.setInsertionPoint(
206  insertSliceOp->template getParentOfType<scf::InParallelOp>());
207  }
208 
209  // Resolve offsets according to source offsets and strides.
210  SmallVector<Value> resolvedOffsets;
211  // Note: the "insertSlice" case is symmetrical to the extract/subview case:
212  // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
213  // passed as the destination to the helper function.
215  rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
216  insertSliceOp.getMixedStrides(), droppedDims,
217  sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
218 
219  // Reset the insertion point.
220  rewriter.setInsertionPoint(insertSliceOp);
221  // Replace original op.
222  rewriter.replaceOpWithNewOp<OpTy>(
223  insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
224  getAsOpFoldResult(resolvedOffsets), resolvedSizes,
225  insertSliceOp.getMixedStrides());
226 
227  return success();
228  }
229 };
230 
235  patterns.getContext());
236 }
237 
239  RewritePatternSet &patterns) {
240  patterns.add<TransferReadOfExtractSliceOpFolder,
241  InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // Pass registration
246 //===----------------------------------------------------------------------===//
247 
248 namespace {
249 
250 struct FoldTensorSubsetOpsPass final
251  : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
252  void runOnOperation() override;
253 };
254 
255 } // namespace
256 
257 void FoldTensorSubsetOpsPass::runOnOperation() {
258  RewritePatternSet patterns(&getContext());
260  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
261 }
262 
263 std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
264  return std::make_unique<FoldTensorSubsetOpsPass>();
265 }
static Value getTensorOperand(vector::TransferReadOp op)
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(RewriterBase &rewriter, XferOp xferOp, ExtractOrInsertOp extractOrInsertSliceOp)
static MLIRContext * getContext(OpFoldResult val)
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void resolveSizesIntoOpWithSizes(ArrayRef< OpFoldResult > sourceSizes, ArrayRef< OpFoldResult > destSizes, const llvm::SmallBitVector &rankReducedSourceDims, SmallVectorImpl< OpFoldResult > &resolvedSizes)
Given sourceSizes, destSizes and information about which dimensions are dropped by the source: rankRe...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
std::unique_ptr< Pass > createFoldTensorSubsetOpsPass()
Creates an instance of the tensor subset folding pass.
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
Definition: AffineMap.cpp:930
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:152