//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace memref {

/// Returns true if the memref type has a static shape and represents a
/// contiguous chunk of memory.
bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
  if (!type.hasStaticShape())
    return false;

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(type.getStridesAndOffset(strides, offset)))
    return false;

  // The memref is contiguous if the outer dimensions are size-1 and the
  // inner dimensions have unit strides.
  int64_t runningStride = 1;
  int64_t curDim = strides.size() - 1;
  // Find all inner dimensions with unit strides.
  while (curDim >= 0 && strides[curDim] == runningStride) {
    runningStride *= type.getDimSize(curDim);
    --curDim;
  }

  // Check that the remaining outer dimensions are size-1.
  while (curDim >= 0 && type.getDimSize(curDim) == 1) {
    --curDim;
  }

  // All dims are unit-strided or size-1.
  return curDim < 0;
}
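
// Illustrative usage sketch (not part of upstream MLIR; the helper name
// `exampleContiguityCheck` is hypothetical). An identity-layout memref<4x8xi8>
// is contiguous row-major, while the same shape with a transposed stride
// layout, memref<4x8xi8, strided<[1, 4]>>, is not.
[[maybe_unused]] static void exampleContiguityCheck(MLIRContext *ctx) {
  Builder b(ctx);
  auto rowMajor = MemRefType::get({4, 8}, b.getI8Type());
  bool isContiguous = isStaticShapeAndContiguousRowMajor(rowMajor); // true

  auto transposed = MemRefType::get(
      {4, 8}, b.getI8Type(),
      StridedLayoutAttr::get(ctx, /*offset=*/0, /*strides=*/{1, 4}));
  bool isTransposedContiguous =
      isStaticShapeAndContiguousRowMajor(transposed); // false
  (void)isContiguous;
  (void)isTransposedContiguous;
}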

std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
    OpBuilder &builder, Location loc, int srcBits, int dstBits,
    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
  unsigned sourceRank = sizes.size();
  assert(sizes.size() == strides.size() &&
         "expected as many sizes as strides for a memref");
  SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
  if (indices.empty())
    indicesVec.resize(sourceRank, builder.getIndexAttr(0));
  assert(indicesVec.size() == strides.size() &&
         "expected as many indices as rank of memref");

  // Create the affine symbols and values for linearization.
  SmallVector<AffineExpr> symbols(2 * sourceRank);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  AffineExpr addMulMap = builder.getAffineConstantExpr(0);

  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);

  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indicesVec[i];
    offsetValues[offsetIdx + 1] = strides[i];
  }
  // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
  int64_t scaler = dstBits / srcBits;
  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap.floorDiv(scaler), offsetValues);

  size_t symbolIndex = 0;
  SmallVector<OpFoldResult> values;
  SmallVector<AffineExpr> productExpressions;
  for (unsigned i = 0; i < sourceRank; ++i) {
    AffineExpr strideExpr = symbols[symbolIndex++];
    values.push_back(strides[i]);
    AffineExpr sizeExpr = symbols[symbolIndex++];
    values.push_back(sizes[i]);

    productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
  }
  AffineMap maxMap = AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
      builder.getContext());
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);

  // Adjust baseOffset by the scale factor (dstBits / srcBits).
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {offset});

  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap % scaler, offsetValues);

  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
          linearizedIndices};
}
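
// Illustrative usage sketch (not part of upstream MLIR; the helper name
// `exampleLinearizeI4` is hypothetical). Linearizing accesses to a
// memref<3x8xi4> emulated on i8 storage uses srcBits = 4 and dstBits = 8,
// so scaler = 2: two i4 values share one i8 word.
[[maybe_unused]] static void
exampleLinearizeI4(OpBuilder &builder, Location loc,
                   ArrayRef<OpFoldResult> indices) {
  SmallVector<OpFoldResult> sizes = {builder.getIndexAttr(3),
                                     builder.getIndexAttr(8)};
  SmallVector<OpFoldResult> strides = {builder.getIndexAttr(8),
                                       builder.getIndexAttr(1)};
  auto [info, linearizedIndex] = getLinearizedMemRefOffsetAndSize(
      builder, loc, /*srcBits=*/4, /*dstBits=*/8,
      /*offset=*/builder.getIndexAttr(0), sizes, strides, indices);
  // For indices (i, j) the linearized index folds to (i * 8 + j) floordiv 2,
  // the intra-word offset to (i * 8 + j) mod 2, and the linearized size to
  // 12 (24 i4 values packed two per i8 word).
  (void)info;
  (void)linearizedIndex;
}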

LinearizedMemRefInfo
getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
                                 int dstBits, OpFoldResult offset,
                                 ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> strides(sizes.size());
  if (!sizes.empty()) {
    strides.back() = builder.getIndexAttr(1);
    AffineExpr s0, s1;
    bindSymbols(builder.getContext(), s0, s1);
    for (int index = sizes.size() - 1; index > 0; --index) {
      strides[index - 1] = affine::makeComposedFoldedAffineApply(
          builder, loc, s0 * s1,
          ArrayRef<OpFoldResult>{strides[index], sizes[index]});
    }
  }

  LinearizedMemRefInfo linearizedMemRefInfo;
  std::tie(linearizedMemRefInfo, std::ignore) =
      getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
                                       sizes, strides);
  return linearizedMemRefInfo;
}
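
// Illustrative usage sketch (not part of upstream MLIR; the helper name
// `exampleLinearizeSizesOnly` is hypothetical). This overload derives the
// row-major strides from the sizes itself, so flattening the same
// memref<3x8xi4> for i8 storage only needs the sizes.
[[maybe_unused]] static void exampleLinearizeSizesOnly(OpBuilder &builder,
                                                       Location loc) {
  SmallVector<OpFoldResult> sizes = {builder.getIndexAttr(3),
                                     builder.getIndexAttr(8)};
  LinearizedMemRefInfo info = getLinearizedMemRefOffsetAndSize(
      builder, loc, /*srcBits=*/4, /*dstBits=*/8,
      /*offset=*/builder.getIndexAttr(0), sizes);
  (void)info; // The linearized size folds to the constant 12, as above.
}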

/// Returns true if none of the uses of `op` read from it.
/// SubViewOp users are allowed as long as all of their users are also
/// store-like (e.g. memref.store / vector.transfer_write). If this returns
/// true it also fills `uses` with the transitive uses; if it returns false,
/// `uses` is left unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
  std::vector<Operation *> opUses;
  for (OpOperand &use : op->getUses()) {
    Operation *useOp = use.getOwner();
    if (isa<memref::DeallocOp>(useOp) ||
        (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
         !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
        (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
      opUses.push_back(useOp);
      continue;
    }
    return false;
  }
  llvm::append_range(uses, opUses);
  return true;
}

/// Erase temporary allocations that are never read from, together with their
/// stores.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
  std::vector<Operation *> opToErase;
  parentOp->walk([&](Operation *op) {
    std::vector<Operation *> candidates;
    if (isa<memref::AllocOp, memref::AllocaOp>(op) &&
        resultIsNotRead(op, candidates)) {
      llvm::append_range(opToErase, candidates);
      opToErase.push_back(op);
    }
  });

  for (Operation *op : opToErase)
    rewriter.eraseOp(op);
}
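
// Illustrative IR sketch (assumed example, not from upstream tests): the
// allocation below is only stored to and deallocated, so the walk erases the
// store, the dealloc, and finally the alloc. A single load from %a, or from
// any subview of it, keeps everything alive.
//
//   %a = memref.alloc() : memref<16xf32>
//   memref.store %v, %a[%i] : memref<16xf32>
//   memref.dealloc %a : memref<16xf32>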

static SmallVector<OpFoldResult>
computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
                                ArrayRef<OpFoldResult> sizes,
                                OpFoldResult unit) {
  SmallVector<OpFoldResult> strides(sizes.size(), unit);
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  for (int64_t r = strides.size() - 1; r > 0; --r) {
    strides[r - 1] = affine::makeComposedFoldedAffineApply(
        builder, loc, s0 * s1, {strides[r], sizes[r]});
  }
  return strides;
}

SmallVector<OpFoldResult>
computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
                            ArrayRef<OpFoldResult> sizes) {
  OpFoldResult unit = builder.getIndexAttr(1);
  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
}
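
// Illustrative usage sketch (not part of upstream MLIR; the helper name
// `exampleSuffixProduct` is hypothetical). For sizes [3, 4, 5] the suffix
// product yields the row-major strides [20, 5, 1]; with constant sizes the
// results fold to attributes instead of materializing affine.apply ops.
[[maybe_unused]] static void exampleSuffixProduct(OpBuilder &builder,
                                                  Location loc) {
  SmallVector<OpFoldResult> sizes = {builder.getIndexAttr(3),
                                     builder.getIndexAttr(4),
                                     builder.getIndexAttr(5)};
  SmallVector<OpFoldResult> strides =
      computeSuffixProductIRBlock(loc, builder, sizes);
  (void)strides; // Folds to [20, 5, 1].
}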

MemrefValue skipFullyAliasingOperations(MemrefValue source) {
  while (auto op = source.getDefiningOp()) {
    if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
        subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
      // A `memref.subview` with an all-zero offset and all unit strides still
      // points to the same memory.
      source = cast<MemrefValue>(subViewOp.getSource());
    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
      // A `memref.cast` still points to the same memory.
      source = castOp.getSource();
    } else {
      return source;
    }
  }
  return source;
}
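
// Illustrative IR sketch (assumed example): starting from %v below,
// skipFullyAliasingOperations returns %a, since a zero-offset, unit-stride
// subview and a cast both point at the same underlying memory:
//
//   %a = memref.alloc() : memref<8x8xf32>
//   %s = memref.subview %a[0, 0] [8, 8] [1, 1]
//       : memref<8x8xf32> to memref<8x8xf32>
//   %v = memref.cast %s : memref<8x8xf32> to memref<?x?xf32>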

MemrefValue skipViewLikeOps(MemrefValue source) {
  while (auto op = source.getDefiningOp()) {
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
      source = cast<MemrefValue>(viewLike.getViewSource());
      continue;
    }
    return source;
  }
  return source;
}
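
// Illustrative IR sketch (assumed example): unlike the helper above, this
// walks through any ViewLikeOpInterface op, so starting from %v it returns
// %a even though the subview has a non-zero offset:
//
//   %a = memref.alloc() : memref<8x8xf32>
//   %v = memref.subview %a[2, 2] [4, 4] [1, 1]
//       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 18>>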

} // namespace memref
} // namespace mlir