MLIR  21.0.0git
MemRefUtils.cpp
Go to the documentation of this file.
1 //===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the MemRef dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "llvm/ADT/STLExtras.h"
20 
21 namespace mlir {
22 namespace memref {
23 
24 bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
25  if (!type.hasStaticShape())
26  return false;
27 
28  SmallVector<int64_t> strides;
29  int64_t offset;
30  if (failed(type.getStridesAndOffset(strides, offset)))
31  return false;
32 
33  // MemRef is contiguous if outer dimensions are size-1 and inner
34  // dimensions have unit strides.
35  int64_t runningStride = 1;
36  int64_t curDim = strides.size() - 1;
37  // Finds all inner dimensions with unit strides.
38  while (curDim >= 0 && strides[curDim] == runningStride) {
39  runningStride *= type.getDimSize(curDim);
40  --curDim;
41  }
42 
43  // Check if other dimensions are size-1.
44  while (curDim >= 0 && type.getDimSize(curDim) == 1) {
45  --curDim;
46  }
47 
48  // All dims are unit-strided or size-1.
49  return curDim < 0;
50 }
51 
52 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
53  OpBuilder &builder, Location loc, int srcBits, int dstBits,
56  unsigned sourceRank = sizes.size();
57  assert(sizes.size() == strides.size() &&
58  "expected as many sizes as strides for a memref");
59  SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
60  if (indices.empty())
61  indicesVec.resize(sourceRank, builder.getIndexAttr(0));
62  assert(indicesVec.size() == strides.size() &&
63  "expected as many indices as rank of memref");
64 
65  // Create the affine symbols and values for linearization.
66  SmallVector<AffineExpr> symbols(2 * sourceRank);
67  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
68  AffineExpr addMulMap = builder.getAffineConstantExpr(0);
69 
70  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
71 
72  for (unsigned i = 0; i < sourceRank; ++i) {
73  unsigned offsetIdx = 2 * i;
74  addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
75  offsetValues[offsetIdx] = indicesVec[i];
76  offsetValues[offsetIdx + 1] = strides[i];
77  }
78  // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
79  int64_t scaler = dstBits / srcBits;
81  builder, loc, addMulMap.floorDiv(scaler), offsetValues);
82 
83  size_t symbolIndex = 0;
85  SmallVector<AffineExpr> productExpressions;
86  for (unsigned i = 0; i < sourceRank; ++i) {
87  AffineExpr strideExpr = symbols[symbolIndex++];
88  values.push_back(strides[i]);
89  AffineExpr sizeExpr = symbols[symbolIndex++];
90  values.push_back(sizes[i]);
91 
92  productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
93  }
94  AffineMap maxMap = AffineMap::get(
95  /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
96  builder.getContext());
97  OpFoldResult linearizedSize =
98  affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
99 
100  // Adjust baseOffset by the scale factor (dstBits / srcBits).
101  AffineExpr s0;
102  bindSymbols(builder.getContext(), s0);
104  builder, loc, s0.floorDiv(scaler), {offset});
105 
107  builder, loc, addMulMap % scaler, offsetValues);
108 
109  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
110  linearizedIndices};
111 }
112 
113 LinearizedMemRefInfo
115  int dstBits, OpFoldResult offset,
116  ArrayRef<OpFoldResult> sizes) {
117  SmallVector<OpFoldResult> strides(sizes.size());
118  if (!sizes.empty()) {
119  strides.back() = builder.getIndexAttr(1);
120  AffineExpr s0, s1;
121  bindSymbols(builder.getContext(), s0, s1);
122  for (int index = sizes.size() - 1; index > 0; --index) {
123  strides[index - 1] = affine::makeComposedFoldedAffineApply(
124  builder, loc, s0 * s1,
125  ArrayRef<OpFoldResult>{strides[index], sizes[index]});
126  }
127  }
128 
129  LinearizedMemRefInfo linearizedMemRefInfo;
130  std::tie(linearizedMemRefInfo, std::ignore) =
131  getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
132  sizes, strides);
133  return linearizedMemRefInfo;
134 }
135 
136 /// Returns true if all the uses of op are not read/load.
137 /// There can be SubviewOp users as long as all its users are also
138 /// StoreOp/transfer_write. If return true it also fills out the uses, if it
139 /// returns false uses is unchanged.
140 static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
141  std::vector<Operation *> opUses;
142  for (OpOperand &use : op->getUses()) {
143  Operation *useOp = use.getOwner();
144  if (isa<memref::DeallocOp>(useOp) ||
145  (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
146  !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
147  (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
148  opUses.push_back(useOp);
149  continue;
150  }
151  return false;
152  }
153  llvm::append_range(uses, opUses);
154  return true;
155 }
156 
157 void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
158  std::vector<Operation *> opToErase;
159  parentOp->walk([&](memref::AllocOp op) {
160  std::vector<Operation *> candidates;
161  if (resultIsNotRead(op, candidates)) {
162  llvm::append_range(opToErase, candidates);
163  opToErase.push_back(op.getOperation());
164  }
165  });
166  for (Operation *op : opToErase)
167  rewriter.eraseOp(op);
168 }
169 
173  OpFoldResult unit) {
174  SmallVector<OpFoldResult> strides(sizes.size(), unit);
175  AffineExpr s0, s1;
176  bindSymbols(builder.getContext(), s0, s1);
177 
178  for (int64_t r = strides.size() - 1; r > 0; --r) {
179  strides[r - 1] = affine::makeComposedFoldedAffineApply(
180  builder, loc, s0 * s1, {strides[r], sizes[r]});
181  }
182  return strides;
183 }
184 
187  ArrayRef<OpFoldResult> sizes) {
188  OpFoldResult unit = builder.getIndexAttr(1);
189  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
190 }
191 
193  while (auto op = source.getDefiningOp()) {
194  if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
195  subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
196  // A `memref.subview` with an all zero offset, and all unit strides, still
197  // points to the same memory.
198  source = cast<MemrefValue>(subViewOp.getSource());
199  } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
200  // A `memref.cast` still points to the same memory.
201  source = castOp.getSource();
202  } else {
203  return source;
204  }
205  }
206  return source;
207 }
208 
210  while (auto op = source.getDefiningOp()) {
211  if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
212  source = cast<MemrefValue>(viewLike.getViewSource());
213  continue;
214  }
215  return source;
216  }
217  return source;
218 }
219 
220 } // namespace memref
221 } // namespace mlir
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:921
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:368
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:204
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:846
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1335
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1224
static bool resultIsNotRead(Operation *op, std::vector< Operation * > &uses)
Returns true if all the uses of op are not read/load.
MemrefValue skipFullyAliasingOperations(MemrefValue source)
Walk up the source chain until an operation that changes/defines the view of memory is found (i....
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
Track temporary allocations that are never read from.
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Definition: MemRefUtils.cpp:52
MemrefValue skipViewLikeOps(MemrefValue source)
Walk up the source chain until we find an operation that is not a view of the source memref (i....
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Definition: MemRefUtils.cpp:24
static SmallVector< OpFoldResult > computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes, OpFoldResult unit)
SmallVector< OpFoldResult > computeSuffixProductIRBlock(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes)
Given a set of sizes, return the suffix product.
Include the generated interface declarations.
TypedValue< BaseMemRefType > MemrefValue
A value with a memref type.
Definition: MemRefUtils.h:26
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
Definition: AffineExpr.h:330
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition: MemRefUtils.h:50