MLIR 22.0.0git
BubbleUpExtractSlice.cpp
Go to the documentation of this file.
1//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns that transform linalg.<op> +
10// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
11// the computation for the linalg op.
12//
13//===----------------------------------------------------------------------===//
14
19
20using namespace mlir;
21using namespace mlir::linalg;
22
23namespace {
24/// Bubble up extract_slice above Linalg operation.
25///
26/// A sequence of operations
27///
28/// ```mlir
29/// %0 = linalg.<op> ... arg0, arg1, ...
30/// %1 = tensor.extract_slice %0 ...
31/// ```
32///
33/// can be replaced with
34///
35/// ```mlir
36/// %0 = tensor.extract_slice %arg0
37/// %1 = tensor.extract_slice %arg1
38/// %2 = linalg.<op> ... %0, %1, ...
39/// ```
40///
41/// This results in the reduce computation of the linalg operation.
42///
43struct BubbleUpExtractSliceOpPattern
44 : OpRewritePattern<tensor::ExtractSliceOp> {
45 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
46
47 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
48 PatternRewriter &rewriter) const final {
49 Value source = sliceOp.getSource();
50 auto linalgOp = source.getDefiningOp<LinalgOp>();
51 if (!linalgOp) {
52 return rewriter.notifyMatchFailure(sliceOp,
53 "expected source to be linalg op");
54 }
55
56 // TODO: we might relax this if we want heuristics to detect that all uses
57 // are small portion of the output.
58 if (!linalgOp->hasOneUse()) {
59 return rewriter.notifyMatchFailure(sliceOp,
60 "expected single use of linalg op");
61 }
62
63 if (linalgOp.getNumDpsInits() != 1) {
64 return rewriter.notifyMatchFailure(sliceOp,
65 "expected single output of linalg op");
66 }
67
68 if (!linalgOp.hasPureTensorSemantics()) {
69 return rewriter.notifyMatchFailure(sliceOp,
70 "expected tensor of linalg op");
71 }
72
73 if (!sliceOp.hasUnitStride())
74 return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
75
76 if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
77 return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
78 }
79
80 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
81 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
82 if (!indexingMap.isProjectedPermutation()) {
83 return rewriter.notifyMatchFailure(
84 sliceOp, "expected a projected permutation for output");
85 }
86
87 auto linalgLoc = linalgOp.getLoc();
88 SmallVector<OpFoldResult> allShapeSizes =
89 linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
90 AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
91 if (!shapeSizesToLoopsMap) {
92 return rewriter.notifyMatchFailure(
93 linalgOp, "failed to get loops map from shape sizes");
94 }
95 SmallVector<OpFoldResult> sizeBounds =
97 rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);
98
99 // The offsets and sizes from the slice operation only give you the tile
100 // size of the output. Use that compute the tile sizes and offsets of the
101 // loops. For loops not used to access the output, set the tile sizes to
102 // loop bounds and set the offset to 0.
103 SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
104 rewriter.getIndexAttr(0));
105 SmallVector<OpFoldResult> tileSizes = sizeBounds;
106 for (auto const &result : enumerate(indexingMap.getResults())) {
107 unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
108 tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
109 tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
110 }
111
112 SmallVector<Value> valuesToTile = linalgOp->getOperands();
113 SmallVector<Value> tiledOperands =
114 makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
115 tileOffsets, tileSizes, sizeBounds,
116 /*omitPartialTileCheck=*/true);
117
118 SmallVector<Type, 4> resultTensorTypes;
119 for (OpOperand &opOperand : linalgOp.getDpsInitsMutable())
120 resultTensorTypes.push_back(
121 tiledOperands[opOperand.getOperandNumber()].getType());
122
123 Operation *newOp =
124 clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
125 rewriter.replaceOp(sliceOp, newOp->getResults());
126 return success();
127 }
128};
129} // namespace
130
133 auto *context = patterns.getContext();
134 patterns.add<BubbleUpExtractSliceOpPattern>(context);
135}
return success()
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_range getResults()
Definition Operation.h:415
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Patterns that are used to bubble up extract slice op above linalg op.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition Utils.cpp:1732
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...