MLIR 23.0.0git
FoldMemRefAliasOps.cpp
Go to the documentation of this file.
1//===- FoldMemRefAliasOps.cpp - Fold memref alias ops for affine ops ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass contains affine-specific versions of the folding patterns for
10// memref.expand_shape, memref.collapse_shape, and memref.subview, since
11// those all need affine-specific handling that won't fit a general interface.
12//
13//===----------------------------------------------------------------------===//
14
21#include "mlir/IR/AffineMap.h"
23
24namespace mlir {
25namespace affine {
26#define GEN_PASS_DEF_AFFINEFOLDMEMREFALIASOPS
27#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
28} // namespace affine
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::affine;
33
34//===----------------------------------------------------------------------===//
35// Utility functions
36//===----------------------------------------------------------------------===//
37
38/// Given an AffineMap and a list of indices, apply the map to get the
39/// underlying indices (expanding the affine map).
41 Location loc, PatternRewriter &rewriter,
44 llvm::map_to_vector(indices, [](Value v) -> OpFoldResult { return v; }));
45 for (unsigned i : llvm::seq(0u, affineMap.getNumResults())) {
47 rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
48 result.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
49 }
50}
51
52//===----------------------------------------------------------------------===//
53// Patterns
54//===----------------------------------------------------------------------===//
55
56namespace {
57
58struct AffineLoadOpOfSubViewOpFolder final : OpRewritePattern<AffineLoadOp> {
59 using Base::Base;
60
61 LogicalResult matchAndRewrite(AffineLoadOp loadOp,
62 PatternRewriter &rewriter) const override {
63 auto subViewOp = loadOp.getMemref().getDefiningOp<memref::SubViewOp>();
64
65 if (!subViewOp)
66 return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
67
68 SmallVector<Value> indices;
69 expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
70 loadOp.getLoc(), rewriter, indices);
71
72 SmallVector<Value> sourceIndices;
74 rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
75 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
76 sourceIndices);
77
78 rewriter.replaceOpWithNewOp<AffineLoadOp>(loadOp, subViewOp.getSource(),
79 sourceIndices);
80 return success();
81 }
82};
83
84struct AffineLoadOpOfExpandShapeOpFolder final
85 : OpRewritePattern<AffineLoadOp> {
86 using Base::Base;
87
88 LogicalResult matchAndRewrite(AffineLoadOp loadOp,
89 PatternRewriter &rewriter) const override {
90 auto expandShapeOp =
91 loadOp.getMemref().getDefiningOp<memref::ExpandShapeOp>();
92
93 if (!expandShapeOp)
94 return failure();
95
96 SmallVector<Value> indices;
97 expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
98 loadOp.getLoc(), rewriter, indices);
99
100 SmallVector<Value> sourceIndices;
101 // affine.load guarantees that indexes start inbounds, which impacts if our
102 // linearization is `disjoint`.
104 loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
105 /*startsInbounds=*/true);
106
107 rewriter.replaceOpWithNewOp<AffineLoadOp>(
108 loadOp, expandShapeOp.getViewSource(), sourceIndices);
109 return success();
110 }
111};
112
113struct AffineLoadOpOfCollapseShapeOpFolder final
114 : OpRewritePattern<AffineLoadOp> {
115 using Base::Base;
116
117 LogicalResult matchAndRewrite(AffineLoadOp loadOp,
118 PatternRewriter &rewriter) const override {
119 auto collapseShapeOp =
120 loadOp.getMemref().getDefiningOp<memref::CollapseShapeOp>();
121
122 if (!collapseShapeOp)
123 return failure();
124
125 SmallVector<Value> indices;
126 expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
127 loadOp.getLoc(), rewriter, indices);
128
129 SmallVector<Value> sourceIndices;
131 loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices,
132 /*startsInbounds=*/true);
133
134 rewriter.replaceOpWithNewOp<AffineLoadOp>(
135 loadOp, collapseShapeOp.getViewSource(), sourceIndices);
136 return success();
137 }
138};
139
140struct AffineStoreOpOfSubViewOpFolder final : OpRewritePattern<AffineStoreOp> {
141 using Base::Base;
142
143 LogicalResult matchAndRewrite(AffineStoreOp storeOp,
144 PatternRewriter &rewriter) const override {
145 auto subViewOp = storeOp.getMemref().getDefiningOp<memref::SubViewOp>();
146
147 if (!subViewOp)
148 return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
149
150 // For affine ops, we need to apply the map to get the "actual" indices.
151 SmallVector<Value> indices;
152 expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
153 storeOp.getLoc(), rewriter, indices);
154
155 SmallVector<Value> sourceIndices;
157 rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
158 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
159 sourceIndices);
160
161 rewriter.replaceOpWithNewOp<AffineStoreOp>(
162 storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices);
163 return success();
164 }
165};
166
167struct AffineStoreOpOfExpandShapeOpFolder final
168 : OpRewritePattern<AffineStoreOp> {
169 using Base::Base;
170
171 LogicalResult matchAndRewrite(AffineStoreOp storeOp,
172 PatternRewriter &rewriter) const override {
173 auto expandShapeOp =
174 storeOp.getMemref().getDefiningOp<memref::ExpandShapeOp>();
175
176 if (!expandShapeOp)
177 return failure();
178
179 SmallVector<Value> indices;
180 expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
181 storeOp.getLoc(), rewriter, indices);
182
183 SmallVector<Value> sourceIndices;
184 // affine.store guarantees that indexes start inbounds, which impacts if our
185 // linearization is `disjoint`.
187 storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
188 /*startsInbounds=*/true);
189
190 rewriter.replaceOpWithNewOp<AffineStoreOp>(
191 storeOp, storeOp.getValueToStore(), expandShapeOp.getViewSource(),
192 sourceIndices);
193 return success();
194 }
195};
196
197struct AffineStoreOpOfCollapseShapeOpFolder final
198 : OpRewritePattern<AffineStoreOp> {
199 using Base::Base;
200
201 LogicalResult matchAndRewrite(AffineStoreOp storeOp,
202 PatternRewriter &rewriter) const override {
203 auto collapseShapeOp =
204 storeOp.getMemref().getDefiningOp<memref::CollapseShapeOp>();
205
206 if (!collapseShapeOp)
207 return failure();
208
209 // For affine ops, we need to apply the map to get the "actual" indices.
210 SmallVector<Value> indices;
211 expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
212 storeOp.getLoc(), rewriter, indices);
213
214 SmallVector<Value> sourceIndices;
216 storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices,
217 /*startsInbounds=*/true);
218
219 rewriter.replaceOpWithNewOp<AffineStoreOp>(
220 storeOp, storeOp.getValueToStore(), collapseShapeOp.getViewSource(),
221 sourceIndices);
222 return success();
223 }
224};
225
226} // namespace
227
229 RewritePatternSet &patterns) {
230 patterns
231 .add<AffineLoadOpOfSubViewOpFolder, AffineLoadOpOfExpandShapeOpFolder,
232 AffineLoadOpOfCollapseShapeOpFolder, AffineStoreOpOfSubViewOpFolder,
233 AffineStoreOpOfExpandShapeOpFolder,
234 AffineStoreOpOfCollapseShapeOpFolder>(patterns.getContext());
235}
236
237//===----------------------------------------------------------------------===//
238// Pass registration
239//===----------------------------------------------------------------------===//
240
241namespace {
242
243struct AffineFoldMemRefAliasOpsPass final
245 AffineFoldMemRefAliasOpsPass> {
246 void runOnOperation() override;
247};
248
249} // namespace
250
251void AffineFoldMemRefAliasOpsPass::runOnOperation() {
252 RewritePatternSet patterns(&getContext());
254 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
return success()
static void expandToUnderlyingIndices(AffineMap affineMap, ValueRange indices, Location loc, PatternRewriter &rewriter, SmallVectorImpl< Value > &result)
Given an AffineMap and a list of indices, apply the map to get the underlying indices (expanding the ...
b getContext())
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition AffineMap.h:46
unsigned getNumResults() const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void populateAffineFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into affine load/store ops into patterns.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...