MLIR  18.0.0git
MemRefMemorySlot.cpp
Go to the documentation of this file.
1 //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Mem2Reg-related interfaces for MemRef dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinDialect.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/Value.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/ErrorHandling.h"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utilities
32 //===----------------------------------------------------------------------===//
33 
34 /// Walks over the indices of the elements of a tensor of a given `shape` by
35 /// updating `index` in place to the next index. This returns failure if the
36 /// provided index was the last index.
39  for (size_t i = 0; i < shape.size(); ++i) {
40  index[i]++;
41  if (index[i] < shape[i])
42  return success();
43  index[i] = 0;
44  }
45  return failure();
46 }
47 
48 /// Calls `walker` for each index within a tensor of a given `shape`, providing
49 /// the index as an array attribute of the coordinates.
50 template <typename CallableT>
52  CallableT &&walker) {
53  Type indexType = IndexType::get(ctx);
54  SmallVector<int64_t> shapeIter(shape.size(), 0);
55  do {
56  SmallVector<Attribute> indexAsAttr;
57  for (int64_t dim : shapeIter)
58  indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
59  walker(ArrayAttr::get(ctx, indexAsAttr));
60  } while (succeeded(nextIndex(shape, shapeIter)));
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Interfaces for AllocaOp
65 //===----------------------------------------------------------------------===//
66 
67 static bool isSupportedElementType(Type type) {
68  return llvm::isa<MemRefType>(type) ||
69  OpBuilder(type.getContext()).getZeroAttr(type);
70 }
71 
72 SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
73  MemRefType type = getType();
74  if (!isSupportedElementType(type.getElementType()))
75  return {};
76  if (!type.hasStaticShape())
77  return {};
78  // Make sure the memref contains only a single element.
79  if (type.getNumElements() != 1)
80  return {};
81 
82  return {MemorySlot{getResult(), type.getElementType()}};
83 }
84 
85 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
86  RewriterBase &rewriter) {
87  assert(isSupportedElementType(slot.elemType));
88  // TODO: support more types.
90  .Case([&](MemRefType t) {
91  return rewriter.create<memref::AllocaOp>(getLoc(), t);
92  })
93  .Default([&](Type t) {
94  return rewriter.create<arith::ConstantOp>(getLoc(), t,
95  rewriter.getZeroAttr(t));
96  });
97 }
98 
99 void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
100  Value defaultValue,
101  RewriterBase &rewriter) {
102  if (defaultValue.use_empty())
103  rewriter.eraseOp(defaultValue.getDefiningOp());
104  rewriter.eraseOp(*this);
105 }
106 
107 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
108  BlockArgument argument,
109  RewriterBase &rewriter) {}
110 
112 memref::AllocaOp::getDestructurableSlots() {
113  MemRefType memrefType = getType();
114  auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
115  if (!destructurable)
116  return {};
117 
118  std::optional<DenseMap<Attribute, Type>> destructuredType =
119  destructurable.getSubelementIndexMap();
120  if (!destructuredType)
121  return {};
122 
123  DenseMap<Attribute, Type> indexMap;
124  for (auto const &[index, type] : *destructuredType)
125  indexMap.insert({index, MemRefType::get({}, type)});
126 
127  return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
128 }
129 
131 memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
132  const SmallPtrSetImpl<Attribute> &usedIndices,
133  RewriterBase &rewriter) {
134  rewriter.setInsertionPointAfter(*this);
135 
137 
138  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
139  for (Attribute usedIndex : usedIndices) {
140  Type elemType = memrefType.getTypeAtIndex(usedIndex);
141  MemRefType elemPtr = MemRefType::get({}, elemType);
142  auto subAlloca = rewriter.create<memref::AllocaOp>(getLoc(), elemPtr);
143  slotMap.try_emplace<MemorySlot>(usedIndex,
144  {subAlloca.getResult(), elemType});
145  }
146 
147  return slotMap;
148 }
149 
150 void memref::AllocaOp::handleDestructuringComplete(
151  const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
152  assert(slot.ptr == getResult());
153  rewriter.eraseOp(*this);
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // Interfaces for LoadOp/StoreOp
158 //===----------------------------------------------------------------------===//
159 
160 bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
161  return getMemRef() == slot.ptr;
162 }
163 
164 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
165 
166 Value memref::LoadOp::getStored(const MemorySlot &slot,
167  RewriterBase &rewriter) {
168  llvm_unreachable("getStored should not be called on LoadOp");
169 }
170 
171 bool memref::LoadOp::canUsesBeRemoved(
172  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
173  SmallVectorImpl<OpOperand *> &newBlockingUses) {
174  if (blockingUses.size() != 1)
175  return false;
176  Value blockingUse = (*blockingUses.begin())->get();
177  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
178  getResult().getType() == slot.elemType;
179 }
180 
181 DeletionKind memref::LoadOp::removeBlockingUses(
182  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
183  RewriterBase &rewriter, Value reachingDefinition) {
184  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
185  // pointer.
186  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
187  return DeletionKind::Delete;
188 }
189 
190 /// Returns the index of a memref in attribute form, given its indices. Returns
191 /// a null pointer if whether the indices form a valid index for the provided
192 /// MemRefType cannot be computed. The indices must come from a valid memref
193 /// StoreOp or LoadOp.
195  ValueRange indices,
196  MemRefType memrefType) {
198  for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
199  IntegerAttr coordAttr;
200  if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
201  return {};
202  // MemRefType shape dimensions are always positive (checked by verifier).
203  std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
204  if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
205  return {};
206  index.push_back(coordAttr);
207  }
208  return ArrayAttr::get(ctx, index);
209 }
210 
211 bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
212  SmallPtrSetImpl<Attribute> &usedIndices,
213  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
214  if (slot.ptr != getMemRef())
215  return false;
218  if (!index)
219  return false;
220  usedIndices.insert(index);
221  return true;
222 }
223 
224 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
226  RewriterBase &rewriter) {
229  const MemorySlot &memorySlot = subslots.at(index);
230  rewriter.updateRootInPlace(*this, [&]() {
231  setMemRef(memorySlot.ptr);
232  getIndicesMutable().clear();
233  });
234  return DeletionKind::Keep;
235 }
236 
237 bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
238 
239 bool memref::StoreOp::storesTo(const MemorySlot &slot) {
240  return getMemRef() == slot.ptr;
241 }
242 
243 Value memref::StoreOp::getStored(const MemorySlot &slot,
244  RewriterBase &rewriter) {
245  return getValue();
246 }
247 
248 bool memref::StoreOp::canUsesBeRemoved(
249  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
250  SmallVectorImpl<OpOperand *> &newBlockingUses) {
251  if (blockingUses.size() != 1)
252  return false;
253  Value blockingUse = (*blockingUses.begin())->get();
254  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
255  getValue() != slot.ptr && getValue().getType() == slot.elemType;
256 }
257 
258 DeletionKind memref::StoreOp::removeBlockingUses(
259  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
260  RewriterBase &rewriter, Value reachingDefinition) {
261  return DeletionKind::Delete;
262 }
263 
264 bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
265  SmallPtrSetImpl<Attribute> &usedIndices,
266  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
267  if (slot.ptr != getMemRef() || getValue() == slot.ptr)
268  return false;
271  if (!index || !slot.elementPtrs.contains(index))
272  return false;
273  usedIndices.insert(index);
274  return true;
275 }
276 
277 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
279  RewriterBase &rewriter) {
282  const MemorySlot &memorySlot = subslots.at(index);
283  rewriter.updateRootInPlace(*this, [&]() {
284  setMemRef(memorySlot.ptr);
285  getIndicesMutable().clear();
286  });
287  return DeletionKind::Keep;
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // Interfaces for destructurable types
292 //===----------------------------------------------------------------------===//
293 
294 namespace {
295 
296 struct MemRefDestructurableTypeExternalModel
297  : public DestructurableTypeInterface::ExternalModel<
298  MemRefDestructurableTypeExternalModel, MemRefType> {
299  std::optional<DenseMap<Attribute, Type>>
300  getSubelementIndexMap(Type type) const {
301  auto memrefType = llvm::cast<MemRefType>(type);
302  constexpr int64_t maxMemrefSizeForDestructuring = 16;
303  if (!memrefType.hasStaticShape() ||
304  memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
305  memrefType.getNumElements() == 1)
306  return {};
307 
308  DenseMap<Attribute, Type> destructured;
310  memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
311  destructured.insert({index, memrefType.getElementType()});
312  });
313 
314  return destructured;
315  }
316 
317  Type getTypeAtIndex(Type type, Attribute index) const {
318  auto memrefType = llvm::cast<MemRefType>(type);
319  auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
320  if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
321  return {};
322 
323  Type indexType = IndexType::get(memrefType.getContext());
324  for (const auto &[coordAttr, dimSize] :
325  llvm::zip(coordArrAttr, memrefType.getShape())) {
326  auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
327  if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
328  coord.getInt() >= dimSize)
329  return {};
330  }
331 
332  return memrefType.getElementType();
333  }
334 };
335 
336 } // namespace
337 
338 //===----------------------------------------------------------------------===//
339 // Register external models
340 //===----------------------------------------------------------------------===//
341 
343  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
344  MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
345  });
346 }
static MLIRContext * getContext(OpFoldResult val)
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the next index.
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribute of the coordinates.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:214
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the type of the pointer that will be generated to access the ...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.