MLIR  22.0.0git
ExtractAddressComputations.cpp
Go to the documentation of this file.
1 //===- ExtractAddressComputations.cpp - Extract address computations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This transformation pass rewrites loading/storing from/to a memref with
10 /// offsets into loading/storing from/to a subview and without any offset on
11 /// the instruction itself.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 //===----------------------------------------------------------------------===//
29 // Helper functions for the `load base[off0...]`
30 // => `load (subview base[off0...])[0...]` pattern.
31 //===----------------------------------------------------------------------===//
32 
33 // Matches getFailureOrSrcMemRef specs for LoadOp.
34 // \see LoadStoreLikeOpRewriter.
35 static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36  return loadOp.getMemRef();
37 }
38 
39 // Matches rebuildOpFromAddressAndIndices specs for LoadOp.
40 // \see LoadStoreLikeOpRewriter.
41 static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
42  memref::LoadOp loadOp, Value srcMemRef,
43  ArrayRef<Value> indices) {
44  Location loc = loadOp.getLoc();
45  return memref::LoadOp::create(rewriter, loc, srcMemRef, indices,
46  loadOp.getNontemporal());
47 }
48 
49 // Matches getViewSizeForEachDim specs for LoadOp.
50 // \see LoadStoreLikeOpRewriter.
52 getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
53  MemRefType ldTy = loadOp.getMemRefType();
54  unsigned loadRank = ldTy.getRank();
55  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Helper functions for the `store val, base[off0...]`
60 // => `store val, (subview base[off0...])[0...]` pattern.
61 //===----------------------------------------------------------------------===//
62 
63 // Matches getFailureOrSrcMemRef specs for StoreOp.
64 // \see LoadStoreLikeOpRewriter.
65 static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66  return storeOp.getMemRef();
67 }
68 
69 // Matches rebuildOpFromAddressAndIndices specs for StoreOp.
70 // \see LoadStoreLikeOpRewriter.
71 static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
72  memref::StoreOp storeOp, Value srcMemRef,
73  ArrayRef<Value> indices) {
74  Location loc = storeOp.getLoc();
75  return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
76  srcMemRef, indices, storeOp.getNontemporal());
77 }
78 
79 // Matches getViewSizeForEachDim specs for StoreOp.
80 // \see LoadStoreLikeOpRewriter.
82 getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
83  MemRefType ldTy = storeOp.getMemRefType();
84  unsigned loadRank = ldTy.getRank();
85  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // Helper functions for the `ldmatrix base[off0...]`
90 // => `ldmatrix (subview base[off0...])[0...]` pattern.
91 //===----------------------------------------------------------------------===//
92 
93 // Matches getFailureOrSrcMemRef specs for LdMatrixOp.
94 // \see LoadStoreLikeOpRewriter.
95 static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
96  return ldMatrixOp.getSrcMemref();
97 }
98 
99 // Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
100 // \see LoadStoreLikeOpRewriter.
101 static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
102  nvgpu::LdMatrixOp ldMatrixOp,
103  Value srcMemRef,
104  ArrayRef<Value> indices) {
105  Location loc = ldMatrixOp.getLoc();
106  return nvgpu::LdMatrixOp::create(
107  rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
108  ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Helper functions for the `transfer_read base[off0...]`
113 // => `transfer_read (subview base[off0...])[0...]` pattern.
114 //===----------------------------------------------------------------------===//
115 
116 // Matches getFailureOrSrcMemRef specs for TransferReadOp.
117 // \see LoadStoreLikeOpRewriter.
118 template <typename TransferLikeOp>
119 static FailureOr<Value>
120 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
121  Value src = transferLikeOp.getBase();
122  if (isa<MemRefType>(src.getType()))
123  return src;
124  return failure();
125 }
126 
127 // Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
128 // \see LoadStoreLikeOpRewriter.
129 static vector::TransferReadOp
130 rebuildTransferReadOp(RewriterBase &rewriter,
131  vector::TransferReadOp transferReadOp, Value srcMemRef,
132  ArrayRef<Value> indices) {
133  Location loc = transferReadOp.getLoc();
134  return vector::TransferReadOp::create(
135  rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices,
136  transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
137  transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // Helper functions for the `transfer_write base[off0...]`
142 // => `transfer_write (subview base[off0...])[0...]` pattern.
143 //===----------------------------------------------------------------------===//
144 
145 // Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
146 // \see LoadStoreLikeOpRewriter.
147 static vector::TransferWriteOp
148 rebuildTransferWriteOp(RewriterBase &rewriter,
149  vector::TransferWriteOp transferWriteOp, Value srcMemRef,
150  ArrayRef<Value> indices) {
151  Location loc = transferWriteOp.getLoc();
152  return vector::TransferWriteOp::create(
153  rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices,
154  transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
155  transferWriteOp.getInBoundsAttr());
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Generic helper functions used as default implementation in
160 // LoadStoreLikeOpRewriter.
161 //===----------------------------------------------------------------------===//
162 
163 /// Helper function to get the src memref.
164 /// It uses the already defined getFailureOrSrcMemRef but asserts
165 /// that the source is a memref.
166 template <typename LoadStoreLikeOp,
167  FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
168 static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
169  FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
170  assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
171  return *failureOrSrcMemRef;
172 }
173 
174 /// Helper function to get the sizes of the resulting view.
175 /// This function gets the sizes of the source memref then substracts the
176 /// offsets used within \p loadStoreLikeOp. This gives the maximal (for
177 /// inbound) sizes for the view.
178 /// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
179 template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
181 getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
182  LoadStoreLikeOp loadStoreLikeOp) {
183  Location loc = loadStoreLikeOp.getLoc();
184  auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create(
185  rewriter, loc, getSrcMemRef(loadStoreLikeOp));
186  SmallVector<OpFoldResult> srcSizes =
187  extractStridedMetadataOp.getConstifiedMixedSizes();
188  SmallVector<OpFoldResult> indices =
189  getAsOpFoldResult(loadStoreLikeOp.getIndices());
190  SmallVector<OpFoldResult> finalSizes;
191 
192  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
193  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
194 
195  for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
196  finalSizes.push_back(affine::makeComposedFoldedAffineApply(
197  rewriter, loc, s0 - s1, {srcSize, indice}));
198  }
199  return finalSizes;
200 }
201 
202 /// Rewrite a store/load-like op so that all its indices are zeros.
203 /// E.g., %ld = memref.load %base[%off0]...[%offN]
204 /// =>
205 /// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
206 /// %ld = memref.load %new_base[0,..,0] :
207 /// memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
208 ///
209 /// `getSrcMemRef` returns the source memref for the given load-like operation.
210 ///
211 /// `getViewSizeForEachDim` returns the sizes of view that is going to feed
212 /// new operation. This must return one size per dimension of the view.
213 /// The sizes of the view needs to be at least as big as what is actually
214 /// going to be accessed. Use the provided `loadStoreOp` to get the right
215 /// sizes.
216 ///
217 /// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
218 /// LoadStoreLikeOp that reads from srcMemRef[indices].
219 /// The returned operation will be used to replace loadStoreOp.
220 template <typename LoadStoreLikeOp,
221  FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
222  LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
223  RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
224  Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
225  SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
226  RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
227  getGenericOpViewSizeForEachDim<
228  LoadStoreLikeOp,
229  getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
230 struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
232 
233  LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
234  PatternRewriter &rewriter) const override {
235  FailureOr<Value> failureOrSrcMemRef =
236  getFailureOrSrcMemRef(loadStoreLikeOp);
237  if (failed(failureOrSrcMemRef))
238  return rewriter.notifyMatchFailure(loadStoreLikeOp,
239  "source is not a memref");
240  Value srcMemRef = *failureOrSrcMemRef;
241  auto ldStTy = cast<MemRefType>(srcMemRef.getType());
242  unsigned loadStoreRank = ldStTy.getRank();
243  // Don't waste compile time if there is nothing to rewrite.
244  if (loadStoreRank == 0)
245  return rewriter.notifyMatchFailure(loadStoreLikeOp,
246  "0-D accesses don't need rewriting");
247 
248  // If our load already has only zeros as indices there is nothing
249  // to do.
250  SmallVector<OpFoldResult> indices =
251  getAsOpFoldResult(loadStoreLikeOp.getIndices());
252  if (llvm::all_of(indices, isZeroInteger)) {
253  return rewriter.notifyMatchFailure(
254  loadStoreLikeOp, "no computation to extract: offsets are 0s");
255  }
256 
257  // Create the array of ones of the right size.
258  SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
260  getViewSizeForEachDim(rewriter, loadStoreLikeOp);
261  assert(sizes.size() == loadStoreRank &&
262  "Expected one size per load dimension");
263  Location loc = loadStoreLikeOp.getLoc();
264  // The subview inherits its strides from the original memref and will
265  // apply them properly to the input indices.
266  // Therefore the strides multipliers are simply ones.
267  auto subview =
268  memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef,
269  /*offsets=*/indices,
270  /*sizes=*/sizes, /*strides=*/ones);
271  // Rewrite the load/store with the subview as the base pointer.
272  SmallVector<Value> zeros(loadStoreRank,
273  arith::ConstantIndexOp::create(rewriter, loc, 0));
274  LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
275  rewriter, loadStoreLikeOp, subview.getResult(), zeros);
276  rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
277  return success();
278  }
279 };
280 } // namespace
281 
284  patterns.add<
285  LoadStoreLikeOpRewriter<
286  memref::LoadOp,
287  /*getSrcMemRef=*/getLoadOpSrcMemRef,
288  /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
289  /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
290  LoadStoreLikeOpRewriter<
291  memref::StoreOp,
292  /*getSrcMemRef=*/getStoreOpSrcMemRef,
293  /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
294  /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
295  LoadStoreLikeOpRewriter<
296  nvgpu::LdMatrixOp,
297  /*getSrcMemRef=*/getLdMatrixOpSrcMemRef,
298  /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
299  LoadStoreLikeOpRewriter<
300  vector::TransferReadOp,
301  /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
302  /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
303  LoadStoreLikeOpRewriter<
304  vector::TransferWriteOp,
305  /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
306  /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
307  patterns.getContext());
308 }
Base type for affine expression.
Definition: AffineExpr.h:68
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:363
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314