MLIR  19.0.0git
ExtractAddressComputations.cpp
Go to the documentation of this file.
1 //===- ExtractAddressComputations.cpp - Extract address computations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This transformation pass rewrites loading/storing from/to a memref with
10 /// offsets into loading/storing from/to a subview and without any offset on
11 /// the instruction itself.
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 //===----------------------------------------------------------------------===//
29 // Helper functions for the `load base[off0...]`
30 // => `load (subview base[off0...])[0...]` pattern.
31 //===----------------------------------------------------------------------===//
32 
33 // Matches getFailureOrSrcMemRef specs for LoadOp.
34 // \see LoadStoreLikeOpRewriter.
35 static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36  return loadOp.getMemRef();
37 }
38 
39 // Matches rebuildOpFromAddressAndIndices specs for LoadOp.
40 // \see LoadStoreLikeOpRewriter.
41 static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
42  memref::LoadOp loadOp, Value srcMemRef,
43  ArrayRef<Value> indices) {
44  Location loc = loadOp.getLoc();
45  return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
46  loadOp.getNontemporal());
47 }
48 
49 // Matches getViewSizeForEachDim specs for LoadOp.
50 // \see LoadStoreLikeOpRewriter.
52 getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
53  MemRefType ldTy = loadOp.getMemRefType();
54  unsigned loadRank = ldTy.getRank();
55  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Helper functions for the `store val, base[off0...]`
60 // => `store val, (subview base[off0...])[0...]` pattern.
61 //===----------------------------------------------------------------------===//
62 
63 // Matches getFailureOrSrcMemRef specs for StoreOp.
64 // \see LoadStoreLikeOpRewriter.
65 static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66  return storeOp.getMemRef();
67 }
68 
69 // Matches rebuildOpFromAddressAndIndices specs for StoreOp.
70 // \see LoadStoreLikeOpRewriter.
71 static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
72  memref::StoreOp storeOp, Value srcMemRef,
73  ArrayRef<Value> indices) {
74  Location loc = storeOp.getLoc();
75  return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
76  srcMemRef, indices,
77  storeOp.getNontemporal());
78 }
79 
80 // Matches getViewSizeForEachDim specs for StoreOp.
81 // \see LoadStoreLikeOpRewriter.
83 getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
84  MemRefType ldTy = storeOp.getMemRefType();
85  unsigned loadRank = ldTy.getRank();
86  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Helper functions for the `ldmatrix base[off0...]`
91 // => `ldmatrix (subview base[off0...])[0...]` pattern.
92 //===----------------------------------------------------------------------===//
93 
94 // Matches getFailureOrSrcMemRef specs for LdMatrixOp.
95 // \see LoadStoreLikeOpRewriter.
96 static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97  return ldMatrixOp.getSrcMemref();
98 }
99 
100 // Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
101 // \see LoadStoreLikeOpRewriter.
102 static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
103  nvgpu::LdMatrixOp ldMatrixOp,
104  Value srcMemRef,
105  ArrayRef<Value> indices) {
106  Location loc = ldMatrixOp.getLoc();
107  return rewriter.create<nvgpu::LdMatrixOp>(
108  loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
109  ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Helper functions for the `transfer_read base[off0...]`
114 // => `transfer_read (subview base[off0...])[0...]` pattern.
115 //===----------------------------------------------------------------------===//
116 
117 // Matches getFailureOrSrcMemRef specs for TransferReadOp.
118 // \see LoadStoreLikeOpRewriter.
119 template <typename TransferLikeOp>
120 static FailureOr<Value>
121 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122  Value src = transferLikeOp.getSource();
123  if (isa<MemRefType>(src.getType()))
124  return src;
125  return failure();
126 }
127 
128 // Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
129 // \see LoadStoreLikeOpRewriter.
130 static vector::TransferReadOp
131 rebuildTransferReadOp(RewriterBase &rewriter,
132  vector::TransferReadOp transferReadOp, Value srcMemRef,
133  ArrayRef<Value> indices) {
134  Location loc = transferReadOp.getLoc();
135  return rewriter.create<vector::TransferReadOp>(
136  loc, transferReadOp.getResult().getType(), srcMemRef, indices,
137  transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138  transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // Helper functions for the `transfer_write base[off0...]`
143 // => `transfer_write (subview base[off0...])[0...]` pattern.
144 //===----------------------------------------------------------------------===//
145 
146 // Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
147 // \see LoadStoreLikeOpRewriter.
148 static vector::TransferWriteOp
149 rebuildTransferWriteOp(RewriterBase &rewriter,
150  vector::TransferWriteOp transferWriteOp, Value srcMemRef,
151  ArrayRef<Value> indices) {
152  Location loc = transferWriteOp.getLoc();
153  return rewriter.create<vector::TransferWriteOp>(
154  loc, transferWriteOp.getValue(), srcMemRef, indices,
155  transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156  transferWriteOp.getInBoundsAttr());
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // Generic helper functions used as default implementation in
161 // LoadStoreLikeOpRewriter.
162 //===----------------------------------------------------------------------===//
163 
164 /// Helper function to get the src memref.
165 /// It uses the already defined getFailureOrSrcMemRef but asserts
166 /// that the source is a memref.
167 template <typename LoadStoreLikeOp,
168  FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
169 static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170  FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171  assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
172  return *failureOrSrcMemRef;
173 }
174 
175 /// Helper function to get the sizes of the resulting view.
176 /// This function gets the sizes of the source memref then substracts the
177 /// offsets used within \p loadStoreLikeOp. This gives the maximal (for
178 /// inbound) sizes for the view.
179 /// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
180 template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
182 getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
183  LoadStoreLikeOp loadStoreLikeOp) {
184  Location loc = loadStoreLikeOp.getLoc();
185  auto extractStridedMetadataOp =
186  rewriter.create<memref::ExtractStridedMetadataOp>(
187  loc, getSrcMemRef(loadStoreLikeOp));
188  SmallVector<OpFoldResult> srcSizes =
189  extractStridedMetadataOp.getConstifiedMixedSizes();
190  SmallVector<OpFoldResult> indices =
191  getAsOpFoldResult(loadStoreLikeOp.getIndices());
192  SmallVector<OpFoldResult> finalSizes;
193 
194  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
195  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
196 
197  for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
198  finalSizes.push_back(affine::makeComposedFoldedAffineApply(
199  rewriter, loc, s0 - s1, {srcSize, indice}));
200  }
201  return finalSizes;
202 }
203 
204 /// Rewrite a store/load-like op so that all its indices are zeros.
205 /// E.g., %ld = memref.load %base[%off0]...[%offN]
206 /// =>
207 /// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
208 /// %ld = memref.load %new_base[0,..,0] :
209 /// memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
210 ///
211 /// `getSrcMemRef` returns the source memref for the given load-like operation.
212 ///
213 /// `getViewSizeForEachDim` returns the sizes of view that is going to feed
214 /// new operation. This must return one size per dimension of the view.
215 /// The sizes of the view needs to be at least as big as what is actually
216 /// going to be accessed. Use the provided `loadStoreOp` to get the right
217 /// sizes.
218 ///
219 /// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
220 /// LoadStoreLikeOp that reads from srcMemRef[indices].
221 /// The returned operation will be used to replace loadStoreOp.
222 template <typename LoadStoreLikeOp,
223  FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
224  LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
225  RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
226  Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
227  SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
228  RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
229  getGenericOpViewSizeForEachDim<
230  LoadStoreLikeOp,
231  getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
232 struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
234 
235  LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
236  PatternRewriter &rewriter) const override {
237  FailureOr<Value> failureOrSrcMemRef =
238  getFailureOrSrcMemRef(loadStoreLikeOp);
239  if (failed(failureOrSrcMemRef))
240  return rewriter.notifyMatchFailure(loadStoreLikeOp,
241  "source is not a memref");
242  Value srcMemRef = *failureOrSrcMemRef;
243  auto ldStTy = cast<MemRefType>(srcMemRef.getType());
244  unsigned loadStoreRank = ldStTy.getRank();
245  // Don't waste compile time if there is nothing to rewrite.
246  if (loadStoreRank == 0)
247  return rewriter.notifyMatchFailure(loadStoreLikeOp,
248  "0-D accesses don't need rewriting");
249 
250  // If our load already has only zeros as indices there is nothing
251  // to do.
252  SmallVector<OpFoldResult> indices =
253  getAsOpFoldResult(loadStoreLikeOp.getIndices());
254  if (std::all_of(indices.begin(), indices.end(),
255  [](const OpFoldResult &opFold) {
256  return isConstantIntValue(opFold, 0);
257  })) {
258  return rewriter.notifyMatchFailure(
259  loadStoreLikeOp, "no computation to extract: offsets are 0s");
260  }
261 
262  // Create the array of ones of the right size.
263  SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
265  getViewSizeForEachDim(rewriter, loadStoreLikeOp);
266  assert(sizes.size() == loadStoreRank &&
267  "Expected one size per load dimension");
268  Location loc = loadStoreLikeOp.getLoc();
269  // The subview inherits its strides from the original memref and will
270  // apply them properly to the input indices.
271  // Therefore the strides multipliers are simply ones.
272  auto subview =
273  rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
274  /*offsets=*/indices,
275  /*sizes=*/sizes, /*strides=*/ones);
276  // Rewrite the load/store with the subview as the base pointer.
277  SmallVector<Value> zeros(loadStoreRank,
278  rewriter.create<arith::ConstantIndexOp>(loc, 0));
279  LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
280  rewriter, loadStoreLikeOp, subview.getResult(), zeros);
281  rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
282  return success();
283  }
284 };
285 } // namespace
286 
288  RewritePatternSet &patterns) {
289  patterns.add<
290  LoadStoreLikeOpRewriter<
291  memref::LoadOp,
292  /*getSrcMemRef=*/getLoadOpSrcMemRef,
293  /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
294  /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
295  LoadStoreLikeOpRewriter<
296  memref::StoreOp,
297  /*getSrcMemRef=*/getStoreOpSrcMemRef,
298  /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
299  /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
300  LoadStoreLikeOpRewriter<
301  nvgpu::LdMatrixOp,
302  /*getSrcMemRef=*/getLdMatrixOpSrcMemRef,
303  /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
304  LoadStoreLikeOpRewriter<
305  vector::TransferReadOp,
306  /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
307  /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
308  LoadStoreLikeOpRewriter<
309  vector::TransferWriteOp,
310  /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
311  /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
312  patterns.getContext());
313 }
Base type for affine expression.
Definition: AffineExpr.h:69
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:375
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1188
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358