MLIR  19.0.0git
BubbleUpExtractSlice.cpp
Go to the documentation of this file.
1 //===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns that transform linalg.<op> +
10 // tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
11 // the computation for the linalg op.
12 //
13 //===----------------------------------------------------------------------===//
14 
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 namespace {
27 /// Bubble up extract_slice above Linalg operation.
28 ///
29 /// A sequence of operations
30 ///
31 /// ```mlir
32 /// %0 = linalg.<op> ... arg0, arg1, ...
33 /// %1 = tensor.extract_slice %0 ...
34 /// ```
35 ///
36 /// can be replaced with
37 ///
38 /// ```mlir
39 /// %0 = tensor.extract_slice %arg0
40 /// %1 = tensor.extract_slice %arg1
41 /// %2 = linalg.<op> ... %0, %1, ...
42 /// ```
43 ///
44 /// This results in the reduce computation of the linalg operation.
45 ///
46 struct BubbleUpExtractSliceOpPattern
47  : OpRewritePattern<tensor::ExtractSliceOp> {
49 
50  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
51  PatternRewriter &rewriter) const final {
52  Value source = sliceOp.getSource();
53  auto linalgOp = source.getDefiningOp<LinalgOp>();
54  if (!linalgOp) {
55  return rewriter.notifyMatchFailure(sliceOp,
56  "expected source to be linalg op");
57  }
58 
59  // TODO: we might relax this if we want heuristics to detect that all uses
60  // are small portion of the output.
61  if (!linalgOp->hasOneUse()) {
62  return rewriter.notifyMatchFailure(sliceOp,
63  "expected single use of linalg op");
64  }
65 
66  if (linalgOp.getNumDpsInits() != 1) {
67  return rewriter.notifyMatchFailure(sliceOp,
68  "expected single output of linalg op");
69  }
70 
71  if (!linalgOp.hasPureTensorSemantics()) {
72  return rewriter.notifyMatchFailure(sliceOp,
73  "expected tensor of linalg op");
74  }
75 
76  if (!sliceOp.hasUnitStride())
77  return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
78 
79  if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
80  return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
81  }
82 
83  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
84  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
85  if (!indexingMap.isProjectedPermutation()) {
86  return rewriter.notifyMatchFailure(
87  sliceOp, "expected a projected permutation for output");
88  }
89 
90  auto linalgLoc = linalgOp.getLoc();
91  SmallVector<OpFoldResult> allShapeSizes =
92  linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
93  AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
94  if (!shapeSizesToLoopsMap) {
95  return rewriter.notifyMatchFailure(
96  linalgOp, "failed to get loops map from shape sizes");
97  }
98  SmallVector<OpFoldResult> sizeBounds =
100  rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);
101 
102  // The offsets and sizes from the slice operation only give you the tile
103  // size of the output. Use that compute the tile sizes and offsets of the
104  // loops. For loops not used to access the output, set the tile sizes to
105  // loop bounds and set the offset to 0.
106  SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
107  rewriter.getIndexAttr(0));
108  SmallVector<OpFoldResult> tileSizes = sizeBounds;
109  for (auto const &result : enumerate(indexingMap.getResults())) {
110  unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
111  tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
112  tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
113  }
114 
115  SmallVector<Value> valuesToTile = linalgOp->getOperands();
116  SmallVector<Value> tiledOperands =
117  makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
118  tileOffsets, tileSizes, sizeBounds,
119  /*omitPartialTileCheck=*/true);
120 
121  SmallVector<Type, 4> resultTensorTypes;
122  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable())
123  resultTensorTypes.push_back(
124  tiledOperands[opOperand.getOperandNumber()].getType());
125 
126  Operation *newOp =
127  clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
128  rewriter.replaceOp(sliceOp, newOp->getResults());
129  return success();
130  }
131 };
132 } // namespace
133 
135  RewritePatternSet &patterns) {
136  auto *context = patterns.getContext();
137  patterns.add<BubbleUpExtractSliceOpPattern>(context);
138 }
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:579
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1235
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Patterns that are used to bubble up extract slice op above linalg op.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition: Utils.cpp:829
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358