MLIR  18.0.0git
ExtractAddressComputations.cpp
Go to the documentation of this file.
1 //===- ExtractAddressComputations.cpp - Extract address computations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This file provides rewrite patterns that turn loads/stores from/to a memref
10 /// with offsets into loads/stores from/to a subview of that memref, so that
11 /// the access instruction itself no longer carries any offset.
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 //===----------------------------------------------------------------------===//
29 // Helper functions for the `load base[off0...]`
30 // => `load (subview base[off0...])[0...]` pattern.
31 //===----------------------------------------------------------------------===//
32 
33 // Matches getFailureOrSrcMemRef specs for LoadOp.
34 // \see LoadStoreLikeOpRewriter.
35 static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36  return loadOp.getMemRef();
37 }
38 
39 // Matches rebuildOpFromAddressAndIndices specs for LoadOp.
40 // \see LoadStoreLikeOpRewriter.
41 static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
42  memref::LoadOp loadOp, Value srcMemRef,
43  ArrayRef<Value> indices) {
44  Location loc = loadOp.getLoc();
45  return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
46  loadOp.getNontemporal());
47 }
48 
49 // Matches getViewSizeForEachDim specs for LoadOp.
50 // \see LoadStoreLikeOpRewriter.
52 getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
53  MemRefType ldTy = loadOp.getMemRefType();
54  unsigned loadRank = ldTy.getRank();
55  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Helper functions for the `store val, base[off0...]`
60 // => `store val, (subview base[off0...])[0...]` pattern.
61 //===----------------------------------------------------------------------===//
62 
63 // Matches getFailureOrSrcMemRef specs for StoreOp.
64 // \see LoadStoreLikeOpRewriter.
65 static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66  return storeOp.getMemRef();
67 }
68 
69 // Matches rebuildOpFromAddressAndIndices specs for StoreOp.
70 // \see LoadStoreLikeOpRewriter.
71 static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
72  memref::StoreOp storeOp, Value srcMemRef,
73  ArrayRef<Value> indices) {
74  Location loc = storeOp.getLoc();
75  return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
76  srcMemRef, indices,
77  storeOp.getNontemporal());
78 }
79 
80 // Matches getViewSizeForEachDim specs for StoreOp.
81 // \see LoadStoreLikeOpRewriter.
83 getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
84  MemRefType ldTy = storeOp.getMemRefType();
85  unsigned loadRank = ldTy.getRank();
86  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Helper functions for the `ldmatrix base[off0...]`
91 // => `ldmatrix (subview base[off0...])[0...]` pattern.
92 //===----------------------------------------------------------------------===//
93 
94 // Matches getFailureOrSrcMemRef specs for LdMatrixOp.
95 // \see LoadStoreLikeOpRewriter.
96 static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97  return ldMatrixOp.getSrcMemref();
98 }
99 
100 // Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
101 // \see LoadStoreLikeOpRewriter.
102 static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
103  nvgpu::LdMatrixOp ldMatrixOp,
104  Value srcMemRef,
105  ArrayRef<Value> indices) {
106  Location loc = ldMatrixOp.getLoc();
107  return rewriter.create<nvgpu::LdMatrixOp>(
108  loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
109  ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Helper functions for the `transfer_read base[off0...]`
114 // => `transfer_read (subview base[off0...])[0...]` pattern.
115 //===----------------------------------------------------------------------===//
116 
117 // Matches getFailureOrSrcMemRef specs for TransferReadOp.
118 // \see LoadStoreLikeOpRewriter.
119 template <typename TransferLikeOp>
120 static FailureOr<Value>
121 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122  Value src = transferLikeOp.getSource();
123  if (isa<MemRefType>(src.getType()))
124  return src;
125  return failure();
126 }
127 
128 // Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
129 // \see LoadStoreLikeOpRewriter.
130 static vector::TransferReadOp
131 rebuildTransferReadOp(RewriterBase &rewriter,
132  vector::TransferReadOp transferReadOp, Value srcMemRef,
133  ArrayRef<Value> indices) {
134  Location loc = transferReadOp.getLoc();
135  return rewriter.create<vector::TransferReadOp>(
136  loc, transferReadOp.getResult().getType(), srcMemRef, indices,
137  transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138  transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // Helper functions for the `transfer_write base[off0...]`
143 // => `transfer_write (subview base[off0...])[0...]` pattern.
144 //===----------------------------------------------------------------------===//
145 
146 // Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
147 // \see LoadStoreLikeOpRewriter.
148 static vector::TransferWriteOp
149 rebuildTransferWriteOp(RewriterBase &rewriter,
150  vector::TransferWriteOp transferWriteOp, Value srcMemRef,
151  ArrayRef<Value> indices) {
152  Location loc = transferWriteOp.getLoc();
153  return rewriter.create<vector::TransferWriteOp>(
154  loc, transferWriteOp.getValue(), srcMemRef, indices,
155  transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156  transferWriteOp.getInBoundsAttr());
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // Generic helper functions used as default implementation in
161 // LoadStoreLikeOpRewriter.
162 //===----------------------------------------------------------------------===//
163 
164 /// Helper function to get the src memref.
165 /// It uses the already defined getFailureOrSrcMemRef but asserts
166 /// that the source is a memref.
167 template <typename LoadStoreLikeOp,
168  FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
169 static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170  FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171  assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
172  return *failureOrSrcMemRef;
173 }
174 
175 /// Helper function to get the sizes of the resulting view.
176 /// This function gets the sizes of the source memref then substracts the
177 /// offsets used within \p loadStoreLikeOp. This gives the maximal (for
178 /// inbound) sizes for the view.
179 /// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
180 template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
182 getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
183  LoadStoreLikeOp loadStoreLikeOp) {
184  Location loc = loadStoreLikeOp.getLoc();
185  auto extractStridedMetadataOp =
186  rewriter.create<memref::ExtractStridedMetadataOp>(
187  loc, getSrcMemRef(loadStoreLikeOp));
188  SmallVector<OpFoldResult> srcSizes =
189  extractStridedMetadataOp.getConstifiedMixedSizes();
190  SmallVector<OpFoldResult> indices =
191  getAsOpFoldResult(loadStoreLikeOp.getIndices());
192  SmallVector<OpFoldResult> finalSizes;
193 
194  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
195  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
196 
197  for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
198  finalSizes.push_back(affine::makeComposedFoldedAffineApply(
199  rewriter, loc, s0 - s1, {srcSize, indice}));
200  }
201  return finalSizes;
202 }
203 
204 /// Rewrite a store/load-like op so that all its indices are zeros.
205 /// E.g., %ld = memref.load %base[%off0]...[%offN]
206 /// =>
207 /// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
208 /// %ld = memref.load %new_base[0,..,0] :
209 /// memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
210 ///
211 /// `getSrcMemRef` returns the source memref for the given load-like operation.
212 ///
213 /// `getViewSizeForEachDim` returns the sizes of view that is going to feed
214 /// new operation. This must return one size per dimension of the view.
215 /// The sizes of the view needs to be at least as big as what is actually
216 /// going to be accessed. Use the provided `loadStoreOp` to get the right
217 /// sizes.
218 ///
219 /// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
220 /// LoadStoreLikeOp that reads from srcMemRef[indices].
221 /// The returned operation will be used to replace loadStoreOp.
222 template <typename LoadStoreLikeOp,
223  FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
224  LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
225  RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
226  Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
227  SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
228  RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
229  getGenericOpViewSizeForEachDim<
230  LoadStoreLikeOp,
231  getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
232 struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
234 
235  LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
236  PatternRewriter &rewriter) const override {
237  FailureOr<Value> failureOrSrcMemRef =
238  getFailureOrSrcMemRef(loadStoreLikeOp);
239  if (failed(failureOrSrcMemRef))
240  return rewriter.notifyMatchFailure(loadStoreLikeOp,
241  "source is not a memref");
242  Value srcMemRef = *failureOrSrcMemRef;
243  auto ldStTy = cast<MemRefType>(srcMemRef.getType());
244  unsigned loadStoreRank = ldStTy.getRank();
245  // Don't waste compile time if there is nothing to rewrite.
246  if (loadStoreRank == 0)
247  return rewriter.notifyMatchFailure(loadStoreLikeOp,
248  "0-D accesses don't need rewriting");
249 
250  // If our load already has only zeros as indices there is nothing
251  // to do.
252  SmallVector<OpFoldResult> indices =
253  getAsOpFoldResult(loadStoreLikeOp.getIndices());
254  if (std::all_of(indices.begin(), indices.end(),
255  [](const OpFoldResult &opFold) {
256  return isConstantIntValue(opFold, 0);
257  })) {
258  return rewriter.notifyMatchFailure(
259  loadStoreLikeOp, "no computation to extract: offsets are 0s");
260  }
261 
262  // Create the array of ones of the right size.
263  SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
265  getViewSizeForEachDim(rewriter, loadStoreLikeOp);
266  assert(sizes.size() == loadStoreRank &&
267  "Expected one size per load dimension");
268  Location loc = loadStoreLikeOp.getLoc();
269  // The subview inherits its strides from the original memref and will
270  // apply them properly to the input indices.
271  // Therefore the strides multipliers are simply ones.
272  auto subview =
273  rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
274  /*offsets=*/indices,
275  /*sizes=*/sizes, /*strides=*/ones);
276  // Rewrite the load/store with the subview as the base pointer.
277  SmallVector<Value> zeros(loadStoreRank,
278  rewriter.create<arith::ConstantIndexOp>(loc, 0));
279  LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
280  rewriter, loadStoreLikeOp, subview.getResult(), zeros);
281  rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
282  return success();
283  }
284 };
285 } // namespace
286 
288  RewritePatternSet &patterns) {
289  patterns.add<
290  LoadStoreLikeOpRewriter<
291  memref::LoadOp,
292  /*getSrcMemRef=*/getLoadOpSrcMemRef,
293  /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
294  /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
295  LoadStoreLikeOpRewriter<
296  memref::StoreOp,
297  /*getSrcMemRef=*/getStoreOpSrcMemRef,
298  /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
299  /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
300  LoadStoreLikeOpRewriter<
301  nvgpu::LdMatrixOp,
302  /*getSrcMemRef=*/getLdMatrixOpSrcMemRef,
303  /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
304  LoadStoreLikeOpRewriter<
305  vector::TransferReadOp,
306  /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
307  /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
308  LoadStoreLikeOpRewriter<
309  vector::TransferWriteOp,
310  /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
311  /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
312  patterns.getContext());
313 }
Base type for affine expression.
Definition: AffineExpr.h:68
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:357
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1172
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357