MemRefUtils.cpp
//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace memref {

bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
  if (!type.hasStaticShape())
    return false;

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(type.getStridesAndOffset(strides, offset)))
    return false;

  // MemRef is contiguous if outer dimensions are size-1 and inner
  // dimensions have unit strides.
  int64_t runningStride = 1;
  int64_t curDim = strides.size() - 1;
  // Find all inner dimensions with unit strides.
  while (curDim >= 0 && strides[curDim] == runningStride) {
    runningStride *= type.getDimSize(curDim);
    --curDim;
  }

  // Check that the remaining outer dimensions are size-1.
  while (curDim >= 0 && type.getDimSize(curDim) == 1) {
    --curDim;
  }

  // All dims are unit-strided or size-1.
  return curDim < 0;
}
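
// Worked example (illustrative only, not part of the original file):
// memref<2x3x4xi8, strided<[12, 4, 1]>> passes the unit-stride loop for every
// dimension: stride 1 matches runningStride 1 (runningStride becomes 4),
// stride 4 matches (runningStride becomes 12), stride 12 matches, and curDim
// ends at -1, so the function returns true. By contrast, with strides
// [1, 2, 6] the innermost stride 6 != 1 stops the first loop immediately,
// dimension 2 has size 4 != 1, and the function returns false.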

std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
    OpBuilder &builder, Location loc, int srcBits, int dstBits,
    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
  unsigned sourceRank = sizes.size();
  assert(sizes.size() == strides.size() &&
         "expected as many sizes as strides for a memref");
  SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
  if (indices.empty())
    indicesVec.resize(sourceRank, builder.getIndexAttr(0));
  assert(indicesVec.size() == strides.size() &&
         "expected as many indices as rank of memref");

  // Create the affine symbols and values for linearization.
  SmallVector<AffineExpr> symbols(2 * sourceRank);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  AffineExpr addMulMap = builder.getAffineConstantExpr(0);

  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);

  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indicesVec[i];
    offsetValues[offsetIdx + 1] = strides[i];
  }
  // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
  int64_t scaler = dstBits / srcBits;
  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap.floorDiv(scaler), offsetValues);

  size_t symbolIndex = 0;
  SmallVector<OpFoldResult> values;
  SmallVector<AffineExpr> productExpressions;
  for (unsigned i = 0; i < sourceRank; ++i) {
    AffineExpr strideExpr = symbols[symbolIndex++];
    values.push_back(strides[i]);
    AffineExpr sizeExpr = symbols[symbolIndex++];
    values.push_back(sizes[i]);

    productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
  }
  AffineMap maxMap = AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
      builder.getContext());
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);

  // Adjust baseOffset by the scale factor (dstBits / srcBits).
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {offset});

  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap % scaler, offsetValues);

  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
          linearizedIndices};
}
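
// Worked example (illustrative only, with hypothetical values): consider a
// memref<3x8xi4> emulated in i8 storage, i.e. srcBits = 4, dstBits = 8, and
// so scaler = 2, with row-major strides [8, 1] and indices [i, j]. Then:
//   linearizedIndices  = (i * 8 + j * 1) floordiv 2
//   linearizedSize     = max((8 * 3) floordiv 2, (1 * 8) floordiv 2) = 12
//   intraVectorOffset  = (i * 8 + j * 1) mod 2    (which i4 within the i8)
//   adjustBaseOffset   = offset floordiv 2
// i.e. 24 i4 elements pack into 12 i8 storage units.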

LinearizedMemRefInfo
getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
                                 int dstBits, OpFoldResult offset,
                                 ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> strides(sizes.size());
  if (!sizes.empty()) {
    strides.back() = builder.getIndexAttr(1);
    AffineExpr s0, s1;
    bindSymbols(builder.getContext(), s0, s1);
    for (int index = sizes.size() - 1; index > 0; --index) {
      strides[index - 1] = affine::makeComposedFoldedAffineApply(
          builder, loc, s0 * s1,
          ArrayRef<OpFoldResult>{strides[index], sizes[index]});
    }
  }

  LinearizedMemRefInfo linearizedMemRefInfo;
  std::tie(linearizedMemRefInfo, std::ignore) =
      getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
                                       sizes, strides);
  return linearizedMemRefInfo;
}
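
// Usage sketch (hypothetical call site, not part of the original file): this
// overload materializes row-major strides from `sizes` (e.g. sizes [3, 8]
// yield strides [8, 1]) and then forwards to the general overload, discarding
// the linearized index:
//
//   LinearizedMemRefInfo info = getLinearizedMemRefOffsetAndSize(
//       builder, loc, /*srcBits=*/4, /*dstBits=*/8,
//       /*offset=*/builder.getIndexAttr(0), sizes);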

/// Returns true if none of the uses of `op` are reads/loads. View-like users
/// are allowed as long as all of their users are also stores/transfer_writes.
/// If this returns true, it also fills out `uses`; if it returns false,
/// `uses` is left unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
  std::vector<Operation *> opUses;
  for (OpOperand &use : op->getUses()) {
    Operation *useOp = use.getOwner();
    // The use escapes the scope.
    if (useOp->mightHaveTrait<OpTrait::IsTerminator>())
      return false;
    if (isa<memref::DeallocOp>(useOp) ||
        (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
         !hasEffect<MemoryEffects::Read>(useOp)) ||
        (isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) {
      opUses.push_back(useOp);
      continue;
    }
    return false;
  }
  llvm::append_range(uses, opUses);
  return true;
}
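
// Example (illustrative IR, not from the original file): `resultIsNotRead`
// returns true for %alloc below because its only users are a view whose users
// are all writes, and a dealloc:
//
//   %alloc = memref.alloc() : memref<16xf32>
//   %view  = memref.subview %alloc[0] [8] [1] : ... to memref<8xf32, ...>
//   memref.store %c, %view[%i] : memref<8xf32, ...>
//   memref.dealloc %alloc : memref<16xf32>
//
// Adding `%v = memref.load %view[%i]` anywhere would make it return false and
// leave `uses` untouched.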

void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
  std::vector<Operation *> opToErase;
  parentOp->walk([&](Operation *op) {
    std::vector<Operation *> candidates;
    if (isa<memref::AllocOp, memref::AllocaOp>(op) &&
        resultIsNotRead(op, candidates)) {
      llvm::append_range(opToErase, candidates);
      opToErase.push_back(op);
    }
  });

  for (Operation *op : opToErase)
    rewriter.eraseOp(op);
}
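
// Usage sketch (hypothetical driver, not part of the original file): this is
// typically run over a function body to clean up temporaries that are only
// ever written:
//
//   IRRewriter rewriter(funcOp->getContext());
//   memref::eraseDeadAllocAndStores(rewriter, funcOp);
//
// Each dead allocation and all of the store/dealloc/view users collected by
// resultIsNotRead are erased together.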

static SmallVector<OpFoldResult>
computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
                                ArrayRef<OpFoldResult> sizes,
                                OpFoldResult unit) {
  SmallVector<OpFoldResult> strides(sizes.size(), unit);
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  for (int64_t r = strides.size() - 1; r > 0; --r) {
    strides[r - 1] = affine::makeComposedFoldedAffineApply(
        builder, loc, s0 * s1, {strides[r], sizes[r]});
  }
  return strides;
}

SmallVector<OpFoldResult>
computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
                            ArrayRef<OpFoldResult> sizes) {
  OpFoldResult unit = builder.getIndexAttr(1);
  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
}
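
// A constant-folded analogue of the IR-building suffix product above
// (illustrative helper, not part of the original file): for sizes [2, 3, 4]
// it returns the row-major strides [12, 4, 1].
[[maybe_unused]] static SmallVector<int64_t>
exampleSuffixProduct(ArrayRef<int64_t> sizes) {
  SmallVector<int64_t> strides(sizes.size(), /*unit=*/1);
  // Walk from the innermost dimension outward, accumulating the product.
  for (int64_t r = static_cast<int64_t>(strides.size()) - 1; r > 0; --r)
    strides[r - 1] = strides[r] * sizes[r];
  return strides;
}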

MemrefValue skipFullyAliasingOperations(MemrefValue source) {
  while (auto *op = source.getDefiningOp()) {
    if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
        subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
      // A `memref.subview` with an all-zero offset and all-unit strides still
      // points to the same memory.
      source = cast<MemrefValue>(subViewOp.getSource());
    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
      // A `memref.cast` still points to the same memory.
      source = castOp.getSource();
    } else {
      return source;
    }
  }
  return source;
}
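
// Example (illustrative IR): in
//
//   %0 = memref.alloc() : memref<4x4xf32>
//   %1 = memref.subview %0[0, 0] [4, 4] [1, 1] : ...  // zero offset, unit strides
//   %2 = memref.cast %1 : ... to memref<?x?xf32, ...>
//
// skipFullyAliasingOperations(%2) walks back through the cast and the no-op
// subview and returns %0, since all three values alias exactly the same
// memory. A subview with a nonzero offset would stop the walk.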

MemrefValue skipViewLikeOps(MemrefValue source) {
  while (auto *op = source.getDefiningOp()) {
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
      if (source == viewLike.getViewDest()) {
        source = cast<MemrefValue>(viewLike.getViewSource());
        continue;
      }
    }
    return source;
  }
  return source;
}
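
// Example (illustrative): unlike skipFullyAliasingOperations, this walks
// through any ViewLikeOpInterface op, including views that change the layout:
//
//   %0 = memref.alloc() : memref<4x4xf32>
//   %1 = memref.subview %0[1, 1] [2, 2] [1, 1] : ...  // offset view
//
// skipViewLikeOps(%1) returns %0 even though %1 covers only part of %0's
// memory; the walk stops only at values not produced by a view-like op.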

LogicalResult resolveSourceIndicesExpandShape(
    Location loc, PatternRewriter &rewriter,
    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

  // Traverse all reassociation groups to determine the appropriate indices
  // corresponding to each one of them after folding the op.
  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();
    if (groupSize == 1) {
      sourceIndices.push_back(indices[group[0]]);
      continue;
    }
    SmallVector<OpFoldResult> groupBasis =
        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
    SmallVector<Value> groupIndices =
        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
    Value collapsedIndex = affine::AffineLinearizeIndexOp::create(
        rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
    sourceIndices.push_back(collapsedIndex);
  }
  return success();
}
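
// Worked example (illustrative, syntax abbreviated): for
//
//   %e = memref.expand_shape %src [[0, 1], [2]]
//        : memref<12x5xf32> into memref<3x4x5xf32>
//
// a load %e[%i, %j, %k] resolves to source indices [%l, %k], where %l is
// affine.linearize_index [%i, %j] by (3, 4), i.e. %i * 4 + %j, with the
// `disjoint` flag set when startsInbounds is true. The single-index group [2]
// passes %k through unchanged.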

LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  // Note: collapse_shape requires a strided memref, so we can always extract
  // its strided metadata here.
  auto metadata = memref::ExtractStridedMetadataOp::create(
      rewriter, loc, collapseShapeOp.getSrc());
  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
  for (auto [index, group] :
       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();

    if (groupSize == 1) {
      sourceIndices.push_back(index);
      continue;
    }

    SmallVector<OpFoldResult> basis =
        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
    auto delinearize = affine::AffineDelinearizeIndexOp::create(
        rewriter, loc, index, basis, /*hasOuterBound=*/true);
    llvm::append_range(sourceIndices, delinearize.getResults());
  }
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
    for (int64_t i = 0; i < srcRank; i++) {
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}
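
// Worked example (illustrative, syntax abbreviated): for
//
//   %c = memref.collapse_shape %src [[0, 1], [2]]
//        : memref<3x4x5xf32> into memref<12x5xf32>
//
// a load %c[%i, %j] resolves to source indices [%i0, %i1, %j], where
// %i0, %i1 = affine.delinearize_index %i into (3, 4), i.e. %i floordiv 4 and
// %i mod 4. The empty-reassociation case at the end handles collapses to a
// rank-0 memref by feeding a constant zero for every source dimension.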

} // namespace memref
} // namespace mlir