MLIR 22.0.0git
ExtractAddressComputations.cpp
Go to the documentation of this file.
1//===- ExtractAddressComputations.cpp - Extract address computations -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// This transformation pass rewrites loading/storing from/to a memref with
10/// offsets into loading/storing from/to a subview and without any offset on
11/// the instruction itself.
12//
13//===----------------------------------------------------------------------===//
14
23
24using namespace mlir;
25
26namespace {
27
28//===----------------------------------------------------------------------===//
29// Helper functions for the `load base[off0...]`
30// => `load (subview base[off0...])[0...]` pattern.
31//===----------------------------------------------------------------------===//
32
33// Matches getFailureOrSrcMemRef specs for LoadOp.
34// \see LoadStoreLikeOpRewriter.
35static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36 return loadOp.getMemRef();
37}
38
39// Matches rebuildOpFromAddressAndIndices specs for LoadOp.
40// \see LoadStoreLikeOpRewriter.
41static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
42 memref::LoadOp loadOp, Value srcMemRef,
44 Location loc = loadOp.getLoc();
45 return memref::LoadOp::create(rewriter, loc, srcMemRef, indices,
46 loadOp.getNontemporal());
47}
48
49// Matches getViewSizeForEachDim specs for LoadOp.
50// \see LoadStoreLikeOpRewriter.
52getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
53 MemRefType ldTy = loadOp.getMemRefType();
54 unsigned loadRank = ldTy.getRank();
55 return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
56}
57
58//===----------------------------------------------------------------------===//
59// Helper functions for the `store val, base[off0...]`
60// => `store val, (subview base[off0...])[0...]` pattern.
61//===----------------------------------------------------------------------===//
62
63// Matches getFailureOrSrcMemRef specs for StoreOp.
64// \see LoadStoreLikeOpRewriter.
65static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66 return storeOp.getMemRef();
67}
68
69// Matches rebuildOpFromAddressAndIndices specs for StoreOp.
70// \see LoadStoreLikeOpRewriter.
71static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
72 memref::StoreOp storeOp, Value srcMemRef,
74 Location loc = storeOp.getLoc();
75 return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
76 srcMemRef, indices, storeOp.getNontemporal());
77}
78
79// Matches getViewSizeForEachDim specs for StoreOp.
80// \see LoadStoreLikeOpRewriter.
82getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
83 MemRefType ldTy = storeOp.getMemRefType();
84 unsigned loadRank = ldTy.getRank();
85 return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
86}
87
88//===----------------------------------------------------------------------===//
89// Helper functions for the `ldmatrix base[off0...]`
90// => `ldmatrix (subview base[off0...])[0...]` pattern.
91//===----------------------------------------------------------------------===//
92
93// Matches getFailureOrSrcMemRef specs for LdMatrixOp.
94// \see LoadStoreLikeOpRewriter.
95static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
96 return ldMatrixOp.getSrcMemref();
97}
98
99// Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
100// \see LoadStoreLikeOpRewriter.
101static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
102 nvgpu::LdMatrixOp ldMatrixOp,
103 Value srcMemRef,
105 Location loc = ldMatrixOp.getLoc();
106 return nvgpu::LdMatrixOp::create(
107 rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
108 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
109}
110
111//===----------------------------------------------------------------------===//
112// Helper functions for the `transfer_read base[off0...]`
113// => `transfer_read (subview base[off0...])[0...]` pattern.
114//===----------------------------------------------------------------------===//
115
116// Matches getFailureOrSrcMemRef specs for TransferReadOp.
117// \see LoadStoreLikeOpRewriter.
118template <typename TransferLikeOp>
119static FailureOr<Value>
120getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
121 Value src = transferLikeOp.getBase();
122 if (isa<MemRefType>(src.getType()))
123 return src;
124 return failure();
125}
126
127// Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
128// \see LoadStoreLikeOpRewriter.
129static vector::TransferReadOp
130rebuildTransferReadOp(RewriterBase &rewriter,
131 vector::TransferReadOp transferReadOp, Value srcMemRef,
133 Location loc = transferReadOp.getLoc();
134 return vector::TransferReadOp::create(
135 rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices,
136 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
137 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
138}
139
140//===----------------------------------------------------------------------===//
141// Helper functions for the `transfer_write base[off0...]`
142// => `transfer_write (subview base[off0...])[0...]` pattern.
143//===----------------------------------------------------------------------===//
144
145// Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
146// \see LoadStoreLikeOpRewriter.
147static vector::TransferWriteOp
148rebuildTransferWriteOp(RewriterBase &rewriter,
149 vector::TransferWriteOp transferWriteOp, Value srcMemRef,
151 Location loc = transferWriteOp.getLoc();
152 return vector::TransferWriteOp::create(
153 rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices,
154 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
155 transferWriteOp.getInBoundsAttr());
156}
157
158//===----------------------------------------------------------------------===//
159// Generic helper functions used as default implementation in
160// LoadStoreLikeOpRewriter.
161//===----------------------------------------------------------------------===//
162
163/// Helper function to get the src memref.
164/// It uses the already defined getFailureOrSrcMemRef but asserts
165/// that the source is a memref.
166template <typename LoadStoreLikeOp,
167 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
168static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
169 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
170 assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
171 return *failureOrSrcMemRef;
172}
173
174/// Helper function to get the sizes of the resulting view.
175/// This function gets the sizes of the source memref then substracts the
176/// offsets used within \p loadStoreLikeOp. This gives the maximal (for
177/// inbound) sizes for the view.
178/// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
179template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
181getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
182 LoadStoreLikeOp loadStoreLikeOp) {
183 Location loc = loadStoreLikeOp.getLoc();
184 auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create(
185 rewriter, loc, getSrcMemRef(loadStoreLikeOp));
187 extractStridedMetadataOp.getConstifiedMixedSizes();
189 getAsOpFoldResult(loadStoreLikeOp.getIndices());
190 SmallVector<OpFoldResult> finalSizes;
191
192 AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
193 AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
194
195 for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
196 finalSizes.push_back(affine::makeComposedFoldedAffineApply(
197 rewriter, loc, s0 - s1, {srcSize, indice}));
198 }
199 return finalSizes;
200}
201
202/// Rewrite a store/load-like op so that all its indices are zeros.
203/// E.g., %ld = memref.load %base[%off0]...[%offN]
204/// =>
205/// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
206/// %ld = memref.load %new_base[0,..,0] :
207/// memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
208///
209/// `getSrcMemRef` returns the source memref for the given load-like operation.
210///
211/// `getViewSizeForEachDim` returns the sizes of view that is going to feed
212/// new operation. This must return one size per dimension of the view.
213/// The sizes of the view needs to be at least as big as what is actually
214/// going to be accessed. Use the provided `loadStoreOp` to get the right
215/// sizes.
216///
217/// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
218/// LoadStoreLikeOp that reads from srcMemRef[indices].
219/// The returned operation will be used to replace loadStoreOp.
220template <typename LoadStoreLikeOp,
221 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
222 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
223 RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
224 Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
225 SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
226 RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
227 getGenericOpViewSizeForEachDim<
228 LoadStoreLikeOp,
229 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
230struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
231 using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;
232
233 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
234 PatternRewriter &rewriter) const override {
235 FailureOr<Value> failureOrSrcMemRef =
236 getFailureOrSrcMemRef(loadStoreLikeOp);
237 if (failed(failureOrSrcMemRef))
238 return rewriter.notifyMatchFailure(loadStoreLikeOp,
239 "source is not a memref");
240 Value srcMemRef = *failureOrSrcMemRef;
241 auto ldStTy = cast<MemRefType>(srcMemRef.getType());
242 unsigned loadStoreRank = ldStTy.getRank();
243 // Don't waste compile time if there is nothing to rewrite.
244 if (loadStoreRank == 0)
245 return rewriter.notifyMatchFailure(loadStoreLikeOp,
246 "0-D accesses don't need rewriting");
247
248 // If our load already has only zeros as indices there is nothing
249 // to do.
251 getAsOpFoldResult(loadStoreLikeOp.getIndices());
252 if (llvm::all_of(indices, isZeroInteger)) {
253 return rewriter.notifyMatchFailure(
254 loadStoreLikeOp, "no computation to extract: offsets are 0s");
255 }
256
257 // Create the array of ones of the right size.
258 SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
260 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
261 assert(sizes.size() == loadStoreRank &&
262 "Expected one size per load dimension");
263 Location loc = loadStoreLikeOp.getLoc();
264 // The subview inherits its strides from the original memref and will
265 // apply them properly to the input indices.
266 // Therefore the strides multipliers are simply ones.
267 auto subview =
268 memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef,
269 /*offsets=*/indices,
270 /*sizes=*/sizes, /*strides=*/ones);
271 // Rewrite the load/store with the subview as the base pointer.
272 SmallVector<Value> zeros(loadStoreRank,
273 arith::ConstantIndexOp::create(rewriter, loc, 0));
274 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
275 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
276 rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
277 return success();
278 }
279};
280} // namespace
281
// NOTE(review): the function header (listing lines 282-283) is missing from
// this extract. Judging from the symbol index at the end of the page it is
// presumably `void populateExtractAddressComputationsPatterns(
// RewritePatternSet &patterns) {` — confirm against the upstream file.
//
// Adds one LoadStoreLikeOpRewriter instantiation per supported op to
// `patterns`: memref::LoadOp, memref::StoreOp, nvgpu::LdMatrixOp,
// vector::TransferReadOp and vector::TransferWriteOp. The load and store
// instantiations pass an explicit view-size hook; the other three rely on
// the default template argument of LoadStoreLikeOpRewriter.
284 patterns.add<
285 LoadStoreLikeOpRewriter<
286 memref::LoadOp,
287 /*getFailureOrSrcMemRef=*/getLoadOpSrcMemRef,
288 /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
289 /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
290 LoadStoreLikeOpRewriter<
291 memref::StoreOp,
292 /*getFailureOrSrcMemRef=*/getStoreOpSrcMemRef,
293 /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
294 /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
295 LoadStoreLikeOpRewriter<
296 nvgpu::LdMatrixOp,
297 /*getFailureOrSrcMemRef=*/getLdMatrixOpSrcMemRef,
298 /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
299 LoadStoreLikeOpRewriter<
300 vector::TransferReadOp,
301 /*getFailureOrSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
302 /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
303 LoadStoreLikeOpRewriter<
304 vector::TransferWriteOp,
305 /*getFailureOrSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
306 /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
307 patterns.getContext());
308}
return success()
Base type for affine expression.
Definition AffineExpr.h:68
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...