MLIR 23.0.0git
MemRefUtils.cpp
Go to the documentation of this file.
1//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the MemRef dialect.
10//
11//===----------------------------------------------------------------------===//
12
18#include "llvm/ADT/STLExtras.h"
19
20namespace mlir {
21namespace memref {
22
// NOTE(review): the signature line is missing from this listing; this body
// presumably belongs to isStaticShapeAndContiguousRowMajor(MemRefType type)
// -- confirm against MemRefUtils.h.
// The stride walk below reads concrete dimension sizes, so types without a
// fully static shape are rejected first.
24 if (!type.hasStaticShape())
25 return false;
26
27 int64_t rank = type.getRank();
// Rank-0 memrefs are reported as contiguous without inspecting strides.
28 if (rank == 0)
29 return true;
30
// NOTE(review): the declaration of `strides` is not visible in this listing.
32 int64_t offset;
// Give up if strides and offset cannot be extracted from the type. The
// extracted `offset` is not consulted by the visible code below.
33 if (failed(type.getStridesAndOffset(strides, offset)))
34 return false;
35
36 // MemRef is contiguous if outer dimensions are size-1 and inner
37 // dimensions have unit strides.
38 int64_t runningStride = 1;
39 int64_t curDim = rank - 1;
40 // Finds all inner dimensions with unit strides.
// `runningStride` is the stride a densely packed row-major layout would have
// at `curDim`: the product of the sizes of all dimensions to its right.
41 while (curDim >= 0 && strides[curDim] == runningStride) {
42 runningStride *= type.getDimSize(curDim);
43 --curDim;
44 }
45
46 // Check if other dimensions are size-1.
// Only the sizes of these remaining outer dimensions are inspected; their
// strides are not compared against anything.
47 while (curDim >= 0 && type.getDimSize(curDim) == 1) {
48 --curDim;
49 }
50
51 // All dims are unit-strided or size-1.
52 return curDim < 0;
53}
54
// Linearizes a multi-dimensional memref access and rescales the result by
// dstBits / srcBits. Returns a pair made of {adjustBaseOffset, linearizedSize,
// intraVectorOffset} and the linearized index.
// NOTE(review): several lines are missing from this listing: the parameter
// lines declaring `offset`, `sizes`, `strides` and `indices`, and the
// declarations of `linearizedIndices`, `values`, `adjustBaseOffset` and
// `intraVectorOffset`. The comments below only describe what the visible
// lines establish.
55std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
56 OpBuilder &builder, Location loc, int srcBits, int dstBits,
59 LinearizedDivKind sizeDivKind) {
60 unsigned sourceRank = sizes.size();
61 assert(sizes.size() == strides.size() &&
62 "expected as many sizes as strides for a memref");
// An empty `indices` list is treated as an access at index 0 in every
// dimension.
63 SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
64 if (indices.empty())
65 indicesVec.resize(sourceRank, builder.getIndexAttr(0));
66 assert(indicesVec.size() == strides.size() &&
67 "expected as many indices as rank of memref");
68
69 // Create the affine symbols and values for linearization.
// Two symbols per dimension: one for the index and one for the stride.
70 SmallVector<AffineExpr> symbols(2 * sourceRank);
71 bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
72 AffineExpr addMulMap = builder.getAffineConstantExpr(0);
73
74 SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
75
// Accumulates sum_i(index_i * stride_i); symbol 2*i is bound to index_i and
// symbol 2*i+1 to stride_i through `offsetValues`.
76 for (unsigned i = 0; i < sourceRank; ++i) {
77 unsigned offsetIdx = 2 * i;
78 addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
79 offsetValues[offsetIdx] = indicesVec[i];
80 offsetValues[offsetIdx + 1] = strides[i];
81 }
82 // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
// Integer division: assumes dstBits is a positive multiple of srcBits
// (otherwise `scaler` truncates, possibly to 0) -- TODO confirm that callers
// guarantee this.
83 int64_t scaler = dstBits / srcBits;
// NOTE(review): the line that declares `linearizedIndices` and opens this call
// is missing; the visible arguments evaluate addMulMap floor-divided by
// `scaler` over `offsetValues`.
85 builder, loc, addMulMap.floorDiv(scaler), offsetValues);
86
87 size_t symbolIndex = 0;
// NOTE(review): the declaration of `values` is missing from this listing.
89 SmallVector<AffineExpr> productExpressions;
// One expression per dimension: (stride_i * size_i) divided by `scaler`,
// rounded up or down according to `sizeDivKind`. The same symbols are reused
// here, this time bound to (stride, size) pairs through `values`.
90 for (unsigned i = 0; i < sourceRank; ++i) {
91 AffineExpr strideExpr = symbols[symbolIndex++];
92 values.push_back(strides[i]);
93 AffineExpr sizeExpr = symbols[symbolIndex++];
94 values.push_back(sizes[i]);
95
96 AffineExpr product = strideExpr * sizeExpr;
97 productExpressions.push_back(sizeDivKind == LinearizedDivKind::Ceil
98 ? product.ceilDiv(scaler)
99 : product.floorDiv(scaler));
100 }
// The linearized size is the maximum over the per-dimension expressions.
101 AffineMap maxMap = AffineMap::get(
102 /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
103 builder.getContext());
104 OpFoldResult linearizedSize =
105 affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
106
107 // Adjust baseOffset by the scale factor (dstBits / srcBits).
108 AffineExpr s0;
109 bindSymbols(builder.getContext(), s0);
// NOTE(review): the line that declares `adjustBaseOffset` and opens this call
// is missing; the visible arguments floor-divide `offset` by `scaler`.
111 builder, loc, s0.floorDiv(scaler), {offset});
112
// NOTE(review): the line that declares `intraVectorOffset` and opens this call
// is missing; the visible arguments take the linear index modulo `scaler`.
114 builder, loc, addMulMap % scaler, offsetValues);
115
116 return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
117 linearizedIndices};
118}
119
// Overload that takes only sizes: it derives strides from `sizes`, forwards to
// the pair-returning getLinearizedMemRefOffsetAndSize with no indices, and
// returns only the LinearizedMemRefInfo half of the result.
// NOTE(review): the line carrying this overload's name and leading parameters
// (and the one declaring `sizes`) is missing from this listing.
120LinearizedMemRefInfo
122 int dstBits, OpFoldResult offset,
124 LinearizedDivKind sizeDivKind) {
125 SmallVector<OpFoldResult> strides(sizes.size());
126 if (!sizes.empty()) {
// The innermost dimension gets stride 1.
127 strides.back() = builder.getIndexAttr(1);
128 AffineExpr s0, s1;
129 bindSymbols(builder.getContext(), s0, s1);
// Walks from the innermost dimension outwards, forming
// strides[index] * sizes[index] at each step.
130 for (int64_t index = static_cast<int64_t>(sizes.size()) - 1; index > 0;
131 --index) {
// NOTE(review): the line holding the assignment target and the start of this
// call is missing; presumably the product is stored into strides[index - 1]
// -- confirm.
133 builder, loc, s0 * s1,
134 ArrayRef<OpFoldResult>{strides[index], sizes[index]});
135 }
136 }
137
// The linearized index (second element of the returned pair) is discarded.
138 LinearizedMemRefInfo linearizedMemRefInfo;
139 std::tie(linearizedMemRefInfo, std::ignore) =
140 getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
141 sizes, strides, /*indices=*/{},
142 sizeDivKind);
143 return linearizedMemRefInfo;
144}
145
146/// Returns true if none of the uses of `op` read/load from it.
147/// View-like users are allowed as long as all of their users are also
148/// StoreOp/transfer_write. If this returns true, the visited users are
149/// appended to `uses`; if it returns false, `uses` is unchanged.
150static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
// Users are staged in a local vector so that `uses` is only modified once the
// whole check has succeeded.
151 std::vector<Operation *> opUses;
152 for (OpOperand &use : op->getUses()) {
153 Operation *useOp = use.getOwner();
154 // Use escaped the scope
// NOTE(review): the condition guarding this early return is missing from this
// listing.
156 return false;
// Accepted users: deallocs; ops with no results and no regions (one further
// condition of that clause is missing from this listing); and view-like ops
// whose own users pass this same check recursively. The recursive call
// appends the nested users to `opUses` before the view-like op itself.
157 if (isa<memref::DeallocOp>(useOp) ||
158 (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
160 (isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) {
161 opUses.push_back(useOp);
162 continue;
163 }
// Any other kind of user makes the whole check fail.
164 return false;
165 }
166 llvm::append_range(uses, opUses);
167 return true;
168}
169
// NOTE(review): the signature line is missing from this listing; the body uses
// `rewriter` and `parentOp` and presumably belongs to
// eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp).
// Collects every alloc/alloca nested under `parentOp` whose result is never
// read, together with the users found by resultIsNotRead, then erases them.
171 std::vector<Operation *> opToErase;
172 parentOp->walk([&](Operation *op) {
173 std::vector<Operation *> candidates;
174 if (isa<memref::AllocOp, memref::AllocaOp>(op) &&
175 resultIsNotRead(op, candidates)) {
// Users are queued ahead of the allocation itself so that each op is erased
// only after the ops using it.
176 llvm::append_range(opToErase, candidates);
177 opToErase.push_back(op);
178 }
179 });
180
// Erasure is deferred until the walk has finished.
181 for (Operation *op : opToErase)
182 rewriter.eraseOp(op);
183}
184
188 OpFoldResult unit) {
// NOTE(review): the leading signature lines are missing from this listing;
// this is presumably computeSuffixProductIRBlockImpl(loc, builder, sizes,
// unit).
// Every entry starts as `unit`; the loop below then forms
// strides[r] * sizes[r] walking from the last dimension down to index 1.
189 SmallVector<OpFoldResult> strides(sizes.size(), unit);
190 AffineExpr s0, s1;
191 bindSymbols(builder.getContext(), s0, s1);
192
// The signed counter keeps the loop from running when `sizes` is empty.
193 for (int64_t r = static_cast<int64_t>(strides.size()) - 1; r > 0; --r) {
// NOTE(review): the line holding the assignment target and the start of this
// call is missing; presumably the product is stored into strides[r - 1]
// -- confirm.
195 builder, loc, s0 * s1, {strides[r], sizes[r]});
196 }
197 return strides;
198}
199
// NOTE(review): the signature lines are missing from this listing; presumably
// computeSuffixProductIRBlock(loc, builder, sizes).
// Delegates to computeSuffixProductIRBlockImpl using the index constant 1 as
// the unit value.
203 OpFoldResult unit = builder.getIndexAttr(1);
204 return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
205}
206
// NOTE(review): the signature line is missing from this listing; presumably
// MemrefValue skipFullyAliasingOperations(MemrefValue source).
// Follows the chain of defining ops upward while each one is either a subview
// with zero offsets and unit strides or a memref.cast, and returns the first
// value that is defined by anything else (or has no defining op).
208 while (auto *op = source.getDefiningOp()) {
209 if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
210 subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
211 // A `memref.subview` with an all zero offset, and all unit strides, still
212 // points to the same memory.
213 source = cast<MemrefValue>(subViewOp.getSource());
214 } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
215 // A `memref.cast` still points to the same memory.
216 source = castOp.getSource();
217 } else {
// Any other defining op ends the walk.
218 return source;
219 }
220 }
// Reached a value with no defining op.
221 return source;
222}
223
// NOTE(review): the signature line is missing from this listing; presumably
// MemrefValue skipViewLikeOps(MemrefValue source).
// Follows the chain of defining ops upward through ops implementing
// ViewLikeOpInterface and returns the first value not produced as the view
// destination of such an op.
225 while (auto *op = source.getDefiningOp()) {
226 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
// Only step through when `source` is the op's view destination; any other
// value defined by a view-like op ends the walk.
227 if (source == viewLike.getViewDest()) {
228 source = cast<MemrefValue>(viewLike.getViewSource());
229 continue;
230 }
231 }
232 return source;
233 }
// Reached a value with no defining op.
234 return source;
235}
236
238 memref::ExpandShapeOp expandShapeOp,
240 SmallVectorImpl<Value> &sourceIndices,
241 bool startsInbounds) {
// NOTE(review): the signature lines declaring the function name, `loc`,
// `rewriter` and `indices` are missing from this listing; presumably
// resolveSourceIndicesExpandShape.
// Maps `indices` (into the expand_shape result) to indices into its source,
// appending one value per reassociation group to `sourceIndices`.
242 SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
243
244 // Traverse all reassociation groups to determine the appropriate indices
245 // corresponding to each one of them post op folding.
246 for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
247 assert(!group.empty() && "association indices groups cannot be empty");
248 int64_t groupSize = group.size();
// A group with a single dimension maps one-to-one; forward that index as is.
249 if (groupSize == 1) {
250 sourceIndices.push_back(indices[group[0]]);
251 continue;
252 }
// Otherwise collapse the group's indices into one, using the group's
// destination sizes as the linearization basis.
253 SmallVector<OpFoldResult> groupBasis =
254 llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
255 SmallVector<Value> groupIndices =
256 llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
// `startsInbounds` is passed through as the op's `disjoint` flag.
257 Value collapsedIndex = affine::AffineLinearizeIndexOp::create(
258 rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
259 sourceIndices.push_back(collapsedIndex);
260 }
261}
262
264 memref::CollapseShapeOp collapseShapeOp,
266 SmallVectorImpl<Value> &sourceIndices,
267 bool startsInbounds) {
// NOTE(review): the signature lines declaring the function name, `loc`,
// `rewriter` and `indices` are missing from this listing; presumably
// resolveSourceIndicesCollapseShape.
// Maps `indices` (into the collapse_shape result) to indices into its source,
// appending one value per source dimension to `sourceIndices`.
268 // Note: collapse_shape requires a strided memref, we can do this.
// Source sizes are obtained through a newly created extract_strided_metadata
// op on the collapse source.
269 auto metadata = memref::ExtractStridedMetadataOp::create(
270 rewriter, loc, collapseShapeOp.getSrc());
271 SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
272 for (auto [index, group] :
273 llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
274 assert(!group.empty() && "association indices groups cannot be empty");
275 int64_t groupSize = group.size();
276
// A group with a single dimension maps one-to-one; forward the index as is.
277 if (groupSize == 1) {
278 sourceIndices.push_back(index);
279 continue;
280 }
281
282 // If we don't know that this value is in-bounds, the largest return value
283 // of the delinearization may exceed `sourceSizes[d]`, so we drop that first
284 // group entry in order to maintain soundness.
285 auto trimmedGroup =
286 ArrayRef<int64_t>(group).drop_front(startsInbounds ? 0 : 1);
287 SmallVector<OpFoldResult> basis = llvm::map_to_vector(
288 trimmedGroup, [&](int64_t d) { return sourceSizes[d]; });
// All results of the delinearize op are appended as source indices.
289 auto delinearize = affine::AffineDelinearizeIndexOp::create(
290 rewriter, loc, index, basis, /*hasOuterBound=*/startsInbounds);
291 llvm::append_range(sourceIndices, delinearize.getResults());
292 }
// With no reassociation groups the loop above produced nothing, so one index
// built from the constant-0 affine map is emitted per source dimension.
293 if (collapseShapeOp.getReassociationIndices().empty()) {
294 auto zeroAffineMap = rewriter.getConstantAffineMap(0);
295 int64_t srcRank =
296 cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
// NOTE(review): the line that declares `ofr` and opens this call is missing
// from this listing.
298 rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
299 for (int64_t i = 0; i < srcRank; i++) {
300 sourceIndices.push_back(
301 getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
302 }
303 }
304}
305
307 Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices,
308 SmallVectorImpl<Value> &sourceIndices) {
// NOTE(review): the line carrying the return type and function name is
// missing from this listing; presumably
// LogicalResult resolveSourceIndicesRankReducingSubview.
// Maps `indices` (into a rank-reducing subview) to indices into its source by
// inserting a constant 0 for every dropped dimension. Fails unless the
// subview has zero offsets and unit strides.
309 if (!subViewOp.hasZeroOffset() || !subViewOp.hasUnitStride())
310 return failure();
311
312 MemRefType srcType = subViewOp.getSourceType();
313 MemRefType resType = subViewOp.getType();
314 unsigned srcRank = srcType.getRank();
315 unsigned resRank = resType.getRank();
// The subview must actually reduce rank, and the caller must supply exactly
// one index per result dimension.
316 if (srcRank <= resRank || indices.size() != resRank)
317 return failure();
318
// The number of dropped dimensions has to account for the whole rank
// difference.
319 auto droppedDims = subViewOp.getDroppedDims();
320 if (droppedDims.none() || droppedDims.count() != srcRank - resRank)
321 return failure();
322
323 auto mixedSizes = subViewOp.getMixedSizes();
324 if (mixedSizes.size() != srcRank)
325 return failure();
326
// `resultDim` counts how many entries of `indices` have been consumed.
327 unsigned resultDim = 0;
328 for (unsigned sourceDim = 0; sourceDim < srcRank; ++sourceDim) {
329 if (droppedDims.test(sourceDim)) {
// A dropped dimension must have a constant size of 1; it is addressed with
// index 0.
330 auto sizeCst = getConstantIntValue(mixedSizes[sourceDim]);
331 if (!sizeCst || *sizeCst != 1)
332 return failure();
333 sourceIndices.push_back(
334 getValueOrCreateConstantIndexOp(b, loc, b.getIndexAttr(0)));
335 continue;
336 }
// Kept dimensions take the next caller-provided index, in order.
337 if (resultDim >= indices.size())
338 return failure();
339 sourceIndices.push_back(indices[resultDim++]);
340 }
// Every provided index must have been used.
// NOTE(review): the failure paths inside and after the loop can leave entries
// already appended to `sourceIndices`; callers presumably discard it on
// failure -- confirm.
341 if (resultDim != indices.size())
342 return failure();
343
344 return success();
345}
346
347} // namespace memref
348} // namespace mlir
return success()
static int64_t product(ArrayRef< int64_t > vals)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:377
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:383
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:782
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition Operation.h:871
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={}, LinearizedDivKind sizeDivKind=LinearizedDivKind::Floor)
static bool resultIsNotRead(Operation *op, std::vector< Operation * > &uses)
Returns true if all the uses of op are not read/load.
MemrefValue skipFullyAliasingOperations(MemrefValue source)
Walk up the source chain until an operation that changes/defines the view of memory is found (i....
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
Track temporary allocations that are never read from.
MemrefValue skipViewLikeOps(MemrefValue source)
Walk up the source chain until we find an operation that is not a view of the source memref (i....
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
LinearizedDivKind
Controls how the per-dimension contribution to linearizedSize is divided by dstBits / srcBits when sc...
Definition MemRefUtils.h:44
LogicalResult resolveSourceIndicesRankReducingSubview(Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a rank-reducing full su...
static SmallVector< OpFoldResult > computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes, OpFoldResult unit)
SmallVector< OpFoldResult > computeSuffixProductIRBlock(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes)
Given a set of sizes, return the suffix product.
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of an expand_shape op,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
TypedValue< BaseMemRefType > MemrefValue
A value with a memref type.
Definition MemRefUtils.h:26
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
Definition AffineExpr.h:330
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition MemRefUtils.h:64