MLIR 23.0.0git
ExtractAddressComputations.cpp
Go to the documentation of this file.
1//===- ExtractAddressComputations.cpp - Extract address computations -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// This transformation pass rewrites loading/storing from/to a memref with
10/// offsets into loading/storing from/to a subview and without any offset on
11/// the instruction itself.
12//
13//===----------------------------------------------------------------------===//
14
23#include "llvm/ADT/Repeated.h"
24
25using namespace mlir;
26
27namespace {
28
29//===----------------------------------------------------------------------===//
30// Helper functions for the `load base[off0...]`
31// => `load (subview base[off0...])[0...]` pattern.
32//===----------------------------------------------------------------------===//
33
34// Matches getFailureOrSrcMemRef specs for LoadOp.
35// \see LoadStoreLikeOpRewriter.
36static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
37 return loadOp.getMemRef();
38}
39
40// Matches rebuildOpFromAddressAndIndices specs for LoadOp.
41// \see LoadStoreLikeOpRewriter.
42static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
43 memref::LoadOp loadOp, Value srcMemRef,
45 Location loc = loadOp.getLoc();
46 return memref::LoadOp::create(rewriter, loc, srcMemRef, indices,
47 loadOp.getNontemporal());
48}
49
50// Matches getViewSizeForEachDim specs for LoadOp.
51// \see LoadStoreLikeOpRewriter.
53getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
54 MemRefType ldTy = loadOp.getMemRefType();
55 unsigned loadRank = ldTy.getRank();
56 return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
57}
58
59//===----------------------------------------------------------------------===//
60// Helper functions for the `store val, base[off0...]`
61// => `store val, (subview base[off0...])[0...]` pattern.
62//===----------------------------------------------------------------------===//
63
64// Matches getFailureOrSrcMemRef specs for StoreOp.
65// \see LoadStoreLikeOpRewriter.
66static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
67 return storeOp.getMemRef();
68}
69
70// Matches rebuildOpFromAddressAndIndices specs for StoreOp.
71// \see LoadStoreLikeOpRewriter.
72static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
73 memref::StoreOp storeOp, Value srcMemRef,
75 Location loc = storeOp.getLoc();
76 return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
77 srcMemRef, indices, storeOp.getNontemporal());
78}
79
80// Matches getViewSizeForEachDim specs for StoreOp.
81// \see LoadStoreLikeOpRewriter.
83getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
84 MemRefType ldTy = storeOp.getMemRefType();
85 unsigned loadRank = ldTy.getRank();
86 return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
87}
88
89//===----------------------------------------------------------------------===//
90// Helper functions for the `ldmatrix base[off0...]`
91// => `ldmatrix (subview base[off0...])[0...]` pattern.
92//===----------------------------------------------------------------------===//
93
94// Matches getFailureOrSrcMemRef specs for LdMatrixOp.
95// \see LoadStoreLikeOpRewriter.
96static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97 return ldMatrixOp.getSrcMemref();
98}
99
100// Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
101// \see LoadStoreLikeOpRewriter.
102static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
103 nvgpu::LdMatrixOp ldMatrixOp,
104 Value srcMemRef,
106 Location loc = ldMatrixOp.getLoc();
107 return nvgpu::LdMatrixOp::create(
108 rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
109 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
110}
111
112//===----------------------------------------------------------------------===//
113// Helper functions for the `transfer_read base[off0...]`
114// => `transfer_read (subview base[off0...])[0...]` pattern.
115//===----------------------------------------------------------------------===//
116
117// Matches getFailureOrSrcMemRef specs for TransferReadOp.
118// \see LoadStoreLikeOpRewriter.
119template <typename TransferLikeOp>
120static FailureOr<Value>
121getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122 Value src = transferLikeOp.getBase();
123 if (isa<MemRefType>(src.getType()))
124 return src;
125 return failure();
126}
127
128// Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
129// \see LoadStoreLikeOpRewriter.
130static vector::TransferReadOp
131rebuildTransferReadOp(RewriterBase &rewriter,
132 vector::TransferReadOp transferReadOp, Value srcMemRef,
134 Location loc = transferReadOp.getLoc();
135 return vector::TransferReadOp::create(
136 rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices,
137 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
139}
140
141//===----------------------------------------------------------------------===//
142// Helper functions for the `transfer_write base[off0...]`
143// => `transfer_write (subview base[off0...])[0...]` pattern.
144//===----------------------------------------------------------------------===//
145
146// Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
147// \see LoadStoreLikeOpRewriter.
148static vector::TransferWriteOp
149rebuildTransferWriteOp(RewriterBase &rewriter,
150 vector::TransferWriteOp transferWriteOp, Value srcMemRef,
152 Location loc = transferWriteOp.getLoc();
153 return vector::TransferWriteOp::create(
154 rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices,
155 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156 transferWriteOp.getInBoundsAttr());
157}
158
159//===----------------------------------------------------------------------===//
160// Generic helper functions used as default implementation in
161// LoadStoreLikeOpRewriter.
162//===----------------------------------------------------------------------===//
163
164/// Helper function to get the src memref.
165/// It uses the already defined getFailureOrSrcMemRef but asserts
166/// that the source is a memref.
167template <typename LoadStoreLikeOp,
168 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
169static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171 assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
172 return *failureOrSrcMemRef;
173}
174
175/// Helper function to get the sizes of the resulting view.
176/// This function gets the sizes of the source memref then substracts the
177/// offsets used within \p loadStoreLikeOp. This gives the maximal (for
178/// inbound) sizes for the view.
179/// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
180template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
182getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
183 LoadStoreLikeOp loadStoreLikeOp) {
184 Location loc = loadStoreLikeOp.getLoc();
185 auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create(
186 rewriter, loc, getSrcMemRef(loadStoreLikeOp));
188 extractStridedMetadataOp.getConstifiedMixedSizes();
190 getAsOpFoldResult(loadStoreLikeOp.getIndices());
191 SmallVector<OpFoldResult> finalSizes;
192
193 AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
194 AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
195
196 for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
197 finalSizes.push_back(affine::makeComposedFoldedAffineApply(
198 rewriter, loc, s0 - s1, {srcSize, indice}));
199 }
200 return finalSizes;
201}
202
203/// Rewrite a store/load-like op so that all its indices are zeros.
204/// E.g., %ld = memref.load %base[%off0]...[%offN]
205/// =>
206/// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
207/// %ld = memref.load %new_base[0,..,0] :
208/// memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
209///
210/// `getSrcMemRef` returns the source memref for the given load-like operation.
211///
212/// `getViewSizeForEachDim` returns the sizes of view that is going to feed
213/// new operation. This must return one size per dimension of the view.
214/// The sizes of the view needs to be at least as big as what is actually
215/// going to be accessed. Use the provided `loadStoreOp` to get the right
216/// sizes.
217///
218/// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
219/// LoadStoreLikeOp that reads from srcMemRef[indices].
220/// The returned operation will be used to replace loadStoreOp.
221template <typename LoadStoreLikeOp,
222 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
223 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
224 RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
225 Value /*srcMemRef*/, ValueRange /*indices*/),
226 SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
227 RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
228 getGenericOpViewSizeForEachDim<
229 LoadStoreLikeOp,
230 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
231struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
232 using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;
233
234 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
235 PatternRewriter &rewriter) const override {
236 FailureOr<Value> failureOrSrcMemRef =
237 getFailureOrSrcMemRef(loadStoreLikeOp);
238 if (failed(failureOrSrcMemRef))
239 return rewriter.notifyMatchFailure(loadStoreLikeOp,
240 "source is not a memref");
241 Value srcMemRef = *failureOrSrcMemRef;
242 auto ldStTy = cast<MemRefType>(srcMemRef.getType());
243 unsigned loadStoreRank = ldStTy.getRank();
244 // Don't waste compile time if there is nothing to rewrite.
245 if (loadStoreRank == 0)
246 return rewriter.notifyMatchFailure(loadStoreLikeOp,
247 "0-D accesses don't need rewriting");
248
249 // If our load already has only zeros as indices there is nothing
250 // to do.
252 getAsOpFoldResult(loadStoreLikeOp.getIndices());
253 if (llvm::all_of(indices, isZeroInteger)) {
254 return rewriter.notifyMatchFailure(
255 loadStoreLikeOp, "no computation to extract: offsets are 0s");
256 }
257
258 // Create the array of ones of the right size.
259 SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
261 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
262 assert(sizes.size() == loadStoreRank &&
263 "Expected one size per load dimension");
264 Location loc = loadStoreLikeOp.getLoc();
265 // The subview inherits its strides from the original memref and will
266 // apply them properly to the input indices.
267 // Therefore the strides multipliers are simply ones.
268 auto subview =
269 memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef,
270 /*offsets=*/indices,
271 /*sizes=*/sizes, /*strides=*/ones);
272 // Rewrite the load/store with the subview as the base pointer.
273 Repeated<Value> zeros(loadStoreRank,
274 arith::ConstantIndexOp::create(rewriter, loc, 0));
275 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
276 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
277 rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
278 return success();
279 }
280};
281} // namespace
282
284 RewritePatternSet &patterns) {
285 patterns.add<
286 LoadStoreLikeOpRewriter<
287 memref::LoadOp,
288 /*getSrcMemRef=*/getLoadOpSrcMemRef,
289 /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
290 /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
291 LoadStoreLikeOpRewriter<
292 memref::StoreOp,
293 /*getSrcMemRef=*/getStoreOpSrcMemRef,
294 /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
295 /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
296 LoadStoreLikeOpRewriter<
297 nvgpu::LdMatrixOp,
298 /*getSrcMemRef=*/getLdMatrixOpSrcMemRef,
299 /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
300 LoadStoreLikeOpRewriter<
301 vector::TransferReadOp,
302 /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
303 /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
304 LoadStoreLikeOpRewriter<
305 vector::TransferWriteOp,
306 /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
307 /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
308 patterns.getContext());
309}
return success()
Base type for affine expression.
Definition AffineExpr.h:68
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...