MLIR  22.0.0git
BubbleUpExtractSlice.cpp
Go to the documentation of this file.
1 //===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns that transform linalg.<op> +
10 // tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
11 // the computation for the linalg op.
12 //
13 //===----------------------------------------------------------------------===//
14 
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 
23 namespace {
24 /// Bubble up extract_slice above Linalg operation.
25 ///
26 /// A sequence of operations
27 ///
28 /// ```mlir
29 /// %0 = linalg.<op> ... arg0, arg1, ...
30 /// %1 = tensor.extract_slice %0 ...
31 /// ```
32 ///
33 /// can be replaced with
34 ///
35 /// ```mlir
36 /// %0 = tensor.extract_slice %arg0
37 /// %1 = tensor.extract_slice %arg1
38 /// %2 = linalg.<op> ... %0, %1, ...
39 /// ```
40 ///
41 /// This results in the reduce computation of the linalg operation.
42 ///
43 struct BubbleUpExtractSliceOpPattern
44  : OpRewritePattern<tensor::ExtractSliceOp> {
46 
47  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
48  PatternRewriter &rewriter) const final {
49  Value source = sliceOp.getSource();
50  auto linalgOp = source.getDefiningOp<LinalgOp>();
51  if (!linalgOp) {
52  return rewriter.notifyMatchFailure(sliceOp,
53  "expected source to be linalg op");
54  }
55 
56  // TODO: we might relax this if we want heuristics to detect that all uses
57  // are small portion of the output.
58  if (!linalgOp->hasOneUse()) {
59  return rewriter.notifyMatchFailure(sliceOp,
60  "expected single use of linalg op");
61  }
62 
63  if (linalgOp.getNumDpsInits() != 1) {
64  return rewriter.notifyMatchFailure(sliceOp,
65  "expected single output of linalg op");
66  }
67 
68  if (!linalgOp.hasPureTensorSemantics()) {
69  return rewriter.notifyMatchFailure(sliceOp,
70  "expected tensor of linalg op");
71  }
72 
73  if (!sliceOp.hasUnitStride())
74  return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
75 
76  if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
77  return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
78  }
79 
80  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
81  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
82  if (!indexingMap.isProjectedPermutation()) {
83  return rewriter.notifyMatchFailure(
84  sliceOp, "expected a projected permutation for output");
85  }
86 
87  auto linalgLoc = linalgOp.getLoc();
88  SmallVector<OpFoldResult> allShapeSizes =
89  linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
90  AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
91  if (!shapeSizesToLoopsMap) {
92  return rewriter.notifyMatchFailure(
93  linalgOp, "failed to get loops map from shape sizes");
94  }
95  SmallVector<OpFoldResult> sizeBounds =
97  rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);
98 
99  // The offsets and sizes from the slice operation only give you the tile
100  // size of the output. Use that compute the tile sizes and offsets of the
101  // loops. For loops not used to access the output, set the tile sizes to
102  // loop bounds and set the offset to 0.
103  SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
104  rewriter.getIndexAttr(0));
105  SmallVector<OpFoldResult> tileSizes = sizeBounds;
106  for (auto const &result : enumerate(indexingMap.getResults())) {
107  unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
108  tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
109  tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
110  }
111 
112  SmallVector<Value> valuesToTile = linalgOp->getOperands();
113  SmallVector<Value> tiledOperands =
114  makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
115  tileOffsets, tileSizes, sizeBounds,
116  /*omitPartialTileCheck=*/true);
117 
118  SmallVector<Type, 4> resultTensorTypes;
119  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable())
120  resultTensorTypes.push_back(
121  tiledOperands[opOperand.getOperandNumber()].getType());
122 
123  Operation *newOp =
124  clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
125  rewriter.replaceOp(sliceOp, newOp->getResults());
126  return success();
127  }
128 };
129 } // namespace
130 
133  auto *context = patterns.getContext();
134  patterns.add<BubbleUpExtractSliceOpPattern>(context);
135 }
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1376
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Patterns that are used to bubble up extract slice op above linalg op.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition: Utils.cpp:862
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314