MLIR 23.0.0git
ExtractAddressComputations.cpp
Go to the documentation of this file.
1//===- ExtractAddressComputations.cpp - Extract address computations -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// This transformation pass rewrites memory access operations with offsets into
10/// accesses through a subview and without any offset on the access operation
11/// itself.
12//
13//===----------------------------------------------------------------------===//
14
23
24using namespace mlir;
25
26namespace {
27
28//===----------------------------------------------------------------------===//
29// Helper functions for the `access base[off0...]`
30// => `access (subview base[off0...])[0...]` pattern.
31//===----------------------------------------------------------------------===//
32
33/// Returns true if every index is zero.
35 return llvm::all_of(getAsOpFoldResult(indices), isZeroInteger);
36}
37
38/// Get the remaining size in each dimension - that is, the size of the memref
39/// dimension minus the index. Used to preserve in_bounds behavior for
40/// transfer_read/write.
41static SmallVector<OpFoldResult> getRemainingSizes(RewriterBase &rewriter,
42 Location loc,
43 Value srcMemRef,
45 auto extractStridedMetadataOp =
46 memref::ExtractStridedMetadataOp::create(rewriter, loc, srcMemRef);
48 extractStridedMetadataOp.getConstifiedMixedSizes();
51
52 AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
53 AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
54
55 for (auto [srcSize, index] : llvm::zip_equal(srcSizes, mixedIndices)) {
56 finalSizes.push_back(affine::makeComposedFoldedAffineApply(
57 rewriter, loc, s0 - s1, {srcSize, index}));
58 }
59 return finalSizes;
60}
61
62/// Get the sizes needed to create a valid subview for an indexed access.
63/// The trailing dimensions are sized using the accessed shape, taking
64/// the minimum of that shape's size and what's available along the relevant
65/// memref dimension as a dynamic value if the memref is dynamically shaped
66/// (so as to avoid subviews that exceed the bounds of the relevant memref
67/// dimension). If the operation accesses a dynamic number of elements along
68/// the dimension, the size of the subview will always be the remaining element
69/// count along the dimension.
71getIndexedAccessViewSizes(RewriterBase &rewriter,
72 memref::IndexedAccessOpInterface op) {
73 TypedValue<MemRefType> srcMemRef = op.getAccessedMemref();
74 assert(srcMemRef && "expected indexed access with a memref");
75
76 MemRefType srcType = srcMemRef.getType();
77 int64_t srcRank = srcType.getRank();
78 SmallVector<int64_t> accessedShape = op.getAccessedShape();
79 int64_t accessedRank = static_cast<int64_t>(accessedShape.size());
80 assert(accessedRank <= srcRank &&
81 "can't access more dimensions than a memref has");
82
84 int64_t firstAccessedDim = srcRank - accessedRank;
85
86 Location loc = op.getLoc();
87 SmallVector<OpFoldResult> viewSizes(srcRank, rewriter.getIndexAttr(1));
89 AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
90 AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
91 AffineExpr cst = rewriter.getAffineSymbolExpr(2);
92
93 auto ensureSrcSizes = [&]() {
94 if (srcSizes.empty()) {
95 auto extractStridedMetadataOp =
96 memref::ExtractStridedMetadataOp::create(rewriter, loc, srcMemRef);
97 srcSizes = extractStridedMetadataOp.getConstifiedMixedSizes();
98 }
99 };
100
101 for (int64_t accessedDim : llvm::seq<int64_t>(0, accessedRank)) {
102 int64_t accessedSize = accessedShape[accessedDim];
103 int64_t dim = firstAccessedDim + accessedDim;
104 if (!ShapedType::isDynamic(accessedSize)) {
105 int64_t srcDimSize = srcType.getDimSize(dim);
106 if (!ShapedType::isDynamic(srcDimSize) || accessedSize == 1) {
107 viewSizes[dim] = rewriter.getIndexAttr(accessedSize);
108 continue;
109 }
110 ensureSrcSizes();
112 rewriter, loc,
113 AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3, {s0 - s1, cst},
114 rewriter.getContext()),
115 {srcSizes[dim], indices[dim], rewriter.getIndexAttr(accessedSize)});
116 } else {
117 ensureSrcSizes();
119 rewriter, loc, s0 - s1, {srcSizes[dim], indices[dim]});
120 }
121 }
122 return viewSizes;
123}
124
125static memref::SubViewOp createSubviewForAccess(RewriterBase &rewriter,
126 Location loc, Value srcMemRef,
129 int64_t rank = cast<MemRefType>(srcMemRef.getType()).getRank();
131 SmallVector<OpFoldResult> ones(rank, rewriter.getIndexAttr(1));
132
133 return memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef,
134 /*offsets=*/mixedIndices,
135 /*sizes=*/sizes, /*strides=*/ones);
136}
137
138static SmallVector<Value> getZeroIndices(RewriterBase &rewriter, Location loc,
139 int64_t rank) {
140 if (rank == 0)
141 return {};
142 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
143 return SmallVector<Value>(rank, zero);
144}
145
146/// Rewrite an indexed access op so that all its indices are zeros.
147/// E.g., %res = indexed_access %base[%off0]...[%offN]
148/// =>
149/// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
150/// %res = indexed_access %new_base[0,..,0] :
151/// memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
152struct IndexedAccessOpRewriter final
153 : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
154 using Base::Base;
155
156 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
157 PatternRewriter &rewriter) const override {
158 TypedValue<MemRefType> srcMemRef = op.getAccessedMemref();
159 if (!srcMemRef)
160 return rewriter.notifyMatchFailure(op, "source is not a memref");
161
162 int64_t rank = srcMemRef.getType().getRank();
163 if (rank == 0)
164 return rewriter.notifyMatchFailure(op,
165 "0-D accesses don't need rewriting");
166
167 if (static_cast<int64_t>(op.getAccessedShape().size()) > rank)
168 return rewriter.notifyMatchFailure(
169 op, "can't access more dimensions than a memref has");
170
171 if (!op.hasInboundsIndices())
172 return rewriter.notifyMatchFailure(op, "indices may be out of bounds");
173
174 // If the access already has only zeros as indices there is nothing
175 // to do.
176 if (hasAllZeroIndices(op.getIndices()))
177 return rewriter.notifyMatchFailure(
178 op, "no computation to extract: offsets are 0s");
179
180 SmallVector<OpFoldResult> subviewSizes =
181 getIndexedAccessViewSizes(rewriter, op);
182
183 Location loc = op.getLoc();
184 auto subview = createSubviewForAccess(rewriter, loc, srcMemRef,
185 op.getIndices(), subviewSizes);
186 SmallVector<Value> zeros = getZeroIndices(rewriter, loc, rank);
187
188 std::optional<SmallVector<Value>> newValues =
189 op.updateMemrefAndIndices(rewriter, subview.getResult(), zeros);
190 if (newValues)
191 rewriter.replaceOp(op, *newValues);
192 return success();
193 }
194};
195
196/// Rewrite a vector transfer op so that all its indices are zeros.
197struct TransferOpRewriter final
198 : OpInterfaceRewritePattern<VectorTransferOpInterface> {
199 using Base::Base;
200
201 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
202 PatternRewriter &rewriter) const override {
203 Value srcMemRef = op.getBase();
204 auto srcType = dyn_cast<MemRefType>(srcMemRef.getType());
205 if (!srcType)
206 return rewriter.notifyMatchFailure(op, "source is not a memref");
207
208 int64_t rank = srcType.getRank();
209
210 if (rank == 0)
211 return rewriter.notifyMatchFailure(op,
212 "0-D accesses don't need rewriting");
213
214 if (hasAllZeroIndices(op.getIndices()))
215 return rewriter.notifyMatchFailure(
216 op, "no computation to extract: offsets are 0s");
217
218 Location loc = op.getLoc();
219 SmallVector<OpFoldResult> offsets = getAsOpFoldResult(op.getIndices());
220 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
221 // Approximate sizes needed so we can test the general case of the
222 // replacement we're planning to do - this can be tightened up later when
223 // this pattern is extended to reason about in_bounds, which dimensions are
224 // accessed, etc.
225 SmallVector<OpFoldResult> approximateSizes(
226 rank, rewriter.getIndexAttr(ShapedType::kDynamic));
227 MemRefType subviewType = memref::SubViewOp::inferResultType(
228 srcType, offsets, approximateSizes, strides);
229 if (!subviewType)
230 return rewriter.notifyMatchFailure(op, "failed to infer subview type");
231
232 AffineMap permutationMap = op.getPermutationMap();
233 if (failed(op.mayUpdateStartingPosition(subviewType, permutationMap)))
234 return rewriter.notifyMatchFailure(op,
235 "failed op-specific preconditions");
236
237 SmallVector<OpFoldResult> sizes =
238 getRemainingSizes(rewriter, loc, srcMemRef, op.getIndices());
239 auto subview = createSubviewForAccess(rewriter, loc, srcMemRef,
240 op.getIndices(), sizes);
241 SmallVector<Value> zeros = getZeroIndices(rewriter, loc, rank);
242
243 op.updateStartingPosition(rewriter, subview.getResult(), zeros,
244 AffineMapAttr::get(permutationMap));
245 return success();
246 }
247};
248} // namespace
249
251 RewritePatternSet &patterns) {
252 patterns.add<IndexedAccessOpRewriter, TransferOpRewriter>(
253 patterns.getContext());
254}
return success()
static bool hasAllZeroIndices(LLVM::GEPOp gepOp)
Base type for affine expression.
Definition AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:373
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from memory access operations such that these ac...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...