MLIR  18.0.0git
BubbleUpExtractSlice.cpp
Go to the documentation of this file.
1 //===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
// the computation performed by the linalg op.
12 //
13 //===----------------------------------------------------------------------===//
14 
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 namespace {
27 /// Bubble up extract_slice above Linalg operation.
28 ///
29 /// A sequence of operations
30 ///
31 /// ```mlir
32 /// %0 = linalg.<op> ... arg0, arg1, ...
33 /// %1 = tensor.extract_slice %0 ...
34 /// ```
35 ///
36 /// can be replaced with
37 ///
38 /// ```mlir
39 /// %0 = tensor.extract_slice %arg0
40 /// %1 = tensor.extract_slice %arg1
41 /// %2 = linalg.<op> ... %0, %1, ...
42 /// ```
43 ///
44 /// This results in the reduce computation of the linalg operation.
45 ///
46 struct BubbleUpExtractSliceOpPattern
47  : OpRewritePattern<tensor::ExtractSliceOp> {
49 
50  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
51  PatternRewriter &rewriter) const final {
52  Value source = sliceOp.getSource();
53  auto linalgOp = source.getDefiningOp<LinalgOp>();
54  if (!linalgOp) {
55  return rewriter.notifyMatchFailure(sliceOp,
56  "expected source to be linalg op");
57  }
58 
59  // TODO: we might relax this if we want heuristics to detect that all uses
60  // are small portion of the output.
61  if (!linalgOp->hasOneUse()) {
62  return rewriter.notifyMatchFailure(sliceOp,
63  "expected single use of linalg op");
64  }
65 
66  if (linalgOp.getNumDpsInits() != 1) {
67  return rewriter.notifyMatchFailure(sliceOp,
68  "expected single output of linalg op");
69  }
70 
71  if (!linalgOp.hasTensorSemantics()) {
72  return rewriter.notifyMatchFailure(sliceOp,
73  "expected tensor of linalg op");
74  }
75 
76  if (!sliceOp.hasUnitStride())
77  return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
78 
79  if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
80  return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
81  }
82 
83  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
84  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
85  if (!indexingMap.isProjectedPermutation()) {
86  return rewriter.notifyMatchFailure(
87  sliceOp, "expected a projected permutation for output");
88  }
89 
90  auto linalgLoc = linalgOp.getLoc();
91  SmallVector<OpFoldResult> allShapeSizes =
92  linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
93  AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
94  if (!shapeSizesToLoopsMap) {
95  return rewriter.notifyMatchFailure(
96  linalgOp, "failed to get loops map from shape sizes");
97  }
98  SmallVector<OpFoldResult> sizeBounds =
100  rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);
101 
102  // The offsets and sizes from the slice operation only give you the tile
103  // size of the output. Use that compute the tile sizes and offsets of the
104  // loops. For loops not used to access the output, set the tile sizes to
105  // loop bounds and set the offset to 0.
106  SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
107  rewriter.getIndexAttr(0));
108  SmallVector<OpFoldResult> tileSizes = sizeBounds;
109  for (auto const &result : enumerate(indexingMap.getResults())) {
110  unsigned position = result.value().cast<AffineDimExpr>().getPosition();
111  tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
112  tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
113  }
114 
115  SmallVector<Value> valuesToTile = linalgOp->getOperands();
116  SmallVector<Value> tiledOperands =
117  makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
118  tileOffsets, tileSizes, sizeBounds,
119  /*omitPartialTileCheck=*/true);
120 
121  SmallVector<Type, 4> resultTensorTypes;
122  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable())
123  resultTensorTypes.push_back(
124  tiledOperands[opOperand.getOperandNumber()].getType());
125 
126  Operation *newOp =
127  clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
128  rewriter.replaceOp(sliceOp, newOp->getResults());
129  return success();
130  }
131 };
132 } // namespace
133 
135  RewritePatternSet &patterns) {
136  auto *context = patterns.getContext();
137  patterns.add<BubbleUpExtractSliceOpPattern>(context);
138 }
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:44
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:534
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:350
This class represents an operand of an operation.
Definition: Value.h:261
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1321
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Patterns that are used to bubble up extract slice op above linalg op.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition: Utils.cpp:829
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357