MLIR 22.0.0git
FoldMemRefAliasOps.cpp
Go to the documentation of this file.
1//===- FoldMemRefAliasOps.cpp - Fold memref alias ops for affine ops ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains affine-specific versions of the folding patterns for
10// memref.expand_shape, memref.collapse_shape, and memref.subview, since
11// those all need affine-specific handling that won't fit a general interface.
12//
13//===----------------------------------------------------------------------===//
14
21#include "mlir/IR/AffineMap.h"
23
24namespace mlir {
25namespace affine {
26#define GEN_PASS_DEF_AFFINEFOLDMEMREFALIASOPS
27#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
28} // namespace affine
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::affine;
33
34//===----------------------------------------------------------------------===//
35// Utility functions
36//===----------------------------------------------------------------------===//
37
38/// Given an AffineMap and a list of indices, apply the map to get the
39/// underlying indices (expanding the affine map).
41 Location loc, PatternRewriter &rewriter,
44 llvm::map_to_vector(indices, [](Value v) -> OpFoldResult { return v; }));
45 for (unsigned i : llvm::seq(0u, affineMap.getNumResults())) {
47 rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
48 result.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
49 }
50}
51
52//===----------------------------------------------------------------------===//
53// Patterns
54//===----------------------------------------------------------------------===//
55
56namespace {
57
58struct AffineLoadOpOfSubViewOpFolder final : OpRewritePattern<AffineLoadOp> {
59 using Base::Base;
60
61 LogicalResult matchAndRewrite(AffineLoadOp loadOp,
62 PatternRewriter &rewriter) const override {
63 auto subViewOp = loadOp.getMemref().getDefiningOp<memref::SubViewOp>();
64
65 if (!subViewOp)
66 return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
67
68 SmallVector<Value> indices;
69 expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
70 loadOp.getLoc(), rewriter, indices);
71
72 SmallVector<Value> sourceIndices;
74 rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
75 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
76 sourceIndices);
77
78 rewriter.replaceOpWithNewOp<AffineLoadOp>(loadOp, subViewOp.getSource(),
79 sourceIndices);
80 return success();
81 }
82};
83
84struct AffineLoadOpOfExpandShapeOpFolder final
85 : OpRewritePattern<AffineLoadOp> {
86 using Base::Base;
87
88 LogicalResult matchAndRewrite(AffineLoadOp loadOp,
89 PatternRewriter &rewriter) const override {
90 auto expandShapeOp =
91 loadOp.getMemref().getDefiningOp<memref::ExpandShapeOp>();
92
93 if (!expandShapeOp)
94 return failure();
95
96 SmallVector<Value> indices;
97 expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
98 loadOp.getLoc(), rewriter, indices);
99
100 SmallVector<Value> sourceIndices;
101 // affine.load guarantees that indexes start inbounds, which impacts if our
102 // linearization is `disjoint`.
104 loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
105 /*startsInbounds=*/true);
106
107 rewriter.replaceOpWithNewOp<AffineLoadOp>(
108 loadOp, expandShapeOp.getViewSource(), sourceIndices);
109 return success();
110 }
111};
112
113struct AffineLoadOpOfCollapseShapeOpFolder final
114 : OpRewritePattern<AffineLoadOp> {
115 using Base::Base;
116
117 LogicalResult matchAndRewrite(AffineLoadOp loadOp,
118 PatternRewriter &rewriter) const override {
119 auto collapseShapeOp =
120 loadOp.getMemref().getDefiningOp<memref::CollapseShapeOp>();
121
122 if (!collapseShapeOp)
123 return failure();
124
125 SmallVector<Value> indices;
126 expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
127 loadOp.getLoc(), rewriter, indices);
128
129 SmallVector<Value> sourceIndices;
131 loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices);
132
133 rewriter.replaceOpWithNewOp<AffineLoadOp>(
134 loadOp, collapseShapeOp.getViewSource(), sourceIndices);
135 return success();
136 }
137};
138
139struct AffineStoreOpOfSubViewOpFolder final : OpRewritePattern<AffineStoreOp> {
140 using Base::Base;
141
142 LogicalResult matchAndRewrite(AffineStoreOp storeOp,
143 PatternRewriter &rewriter) const override {
144 auto subViewOp = storeOp.getMemref().getDefiningOp<memref::SubViewOp>();
145
146 if (!subViewOp)
147 return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
148
149 // For affine ops, we need to apply the map to get the "actual" indices.
150 SmallVector<Value> indices;
151 expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
152 storeOp.getLoc(), rewriter, indices);
153
154 SmallVector<Value> sourceIndices;
156 rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
157 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
158 sourceIndices);
159
160 rewriter.replaceOpWithNewOp<AffineStoreOp>(
161 storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices);
162 return success();
163 }
164};
165
166struct AffineStoreOpOfExpandShapeOpFolder final
167 : OpRewritePattern<AffineStoreOp> {
168 using Base::Base;
169
170 LogicalResult matchAndRewrite(AffineStoreOp storeOp,
171 PatternRewriter &rewriter) const override {
172 auto expandShapeOp =
173 storeOp.getMemref().getDefiningOp<memref::ExpandShapeOp>();
174
175 if (!expandShapeOp)
176 return failure();
177
178 SmallVector<Value> indices;
179 expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
180 storeOp.getLoc(), rewriter, indices);
181
182 SmallVector<Value> sourceIndices;
183 // affine.store guarantees that indexes start inbounds, which impacts if our
184 // linearization is `disjoint`.
186 storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
187 /*startsInbounds=*/true);
188
189 rewriter.replaceOpWithNewOp<AffineStoreOp>(
190 storeOp, storeOp.getValueToStore(), expandShapeOp.getViewSource(),
191 sourceIndices);
192 return success();
193 }
194};
195
196struct AffineStoreOpOfCollapseShapeOpFolder final
197 : OpRewritePattern<AffineStoreOp> {
198 using Base::Base;
199
200 LogicalResult matchAndRewrite(AffineStoreOp storeOp,
201 PatternRewriter &rewriter) const override {
202 auto collapseShapeOp =
203 storeOp.getMemref().getDefiningOp<memref::CollapseShapeOp>();
204
205 if (!collapseShapeOp)
206 return failure();
207
208 // For affine ops, we need to apply the map to get the "actual" indices.
209 SmallVector<Value> indices;
210 expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
211 storeOp.getLoc(), rewriter, indices);
212
213 SmallVector<Value> sourceIndices;
215 storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices);
216
217 rewriter.replaceOpWithNewOp<AffineStoreOp>(
218 storeOp, storeOp.getValueToStore(), collapseShapeOp.getViewSource(),
219 sourceIndices);
220 return success();
221 }
222};
223
224} // namespace
225
229 .add<AffineLoadOpOfSubViewOpFolder, AffineLoadOpOfExpandShapeOpFolder,
230 AffineLoadOpOfCollapseShapeOpFolder, AffineStoreOpOfSubViewOpFolder,
231 AffineStoreOpOfExpandShapeOpFolder,
232 AffineStoreOpOfCollapseShapeOpFolder>(patterns.getContext());
233}
234
235//===----------------------------------------------------------------------===//
236// Pass registration
237//===----------------------------------------------------------------------===//
238
239namespace {
240
241struct AffineFoldMemRefAliasOpsPass final
243 AffineFoldMemRefAliasOpsPass> {
244 void runOnOperation() override;
245};
246
247} // namespace
248
249void AffineFoldMemRefAliasOpsPass::runOnOperation() {
250 RewritePatternSet patterns(&getContext());
252 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
return success()
static void expandToUnderlyingIndices(AffineMap affineMap, ValueRange indices, Location loc, PatternRewriter &rewriter, SmallVectorImpl< Value > &result)
Given an AffineMap and a list of indices, apply the map to get the underlying indices (expanding the ...
b getContext())
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
unsigned getNumResults() const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void populateAffineFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into affine load/store ops into patterns.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...