MLIR  22.0.0git
MemRefMemorySlot.cpp
Go to the documentation of this file.
1 //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Mem2Reg-related interfaces for MemRef dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinDialect.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/Value.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/ErrorHandling.h"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utilities
29 //===----------------------------------------------------------------------===//
30 
31 /// Walks over the indices of the elements of a tensor of a given `shape` by
32 /// updating `index` in place to the next index. This returns failure if the
33 /// provided index was the last index.
34 static LogicalResult nextIndex(ArrayRef<int64_t> shape,
36  for (size_t i = 0; i < shape.size(); ++i) {
37  index[i]++;
38  if (index[i] < shape[i])
39  return success();
40  index[i] = 0;
41  }
42  return failure();
43 }
44 
45 /// Calls `walker` for each index within a tensor of a given `shape`, providing
46 /// the index as an array attribute of the coordinates.
47 template <typename CallableT>
49  CallableT &&walker) {
50  Type indexType = IndexType::get(ctx);
51  SmallVector<int64_t> shapeIter(shape.size(), 0);
52  do {
53  SmallVector<Attribute> indexAsAttr;
54  for (int64_t dim : shapeIter)
55  indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
56  walker(ArrayAttr::get(ctx, indexAsAttr));
57  } while (succeeded(nextIndex(shape, shapeIter)));
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // Interfaces for AllocaOp
62 //===----------------------------------------------------------------------===//
63 
64 static bool isSupportedElementType(Type type) {
65  return llvm::isa<MemRefType>(type) ||
66  OpBuilder(type.getContext()).getZeroAttr(type);
67 }
68 
69 SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
70  MemRefType type = getType();
71  if (!isSupportedElementType(type.getElementType()))
72  return {};
73  if (!type.hasStaticShape())
74  return {};
75  // Make sure the memref contains only a single element.
76  if (type.getNumElements() != 1)
77  return {};
78 
79  return {MemorySlot{getResult(), type.getElementType()}};
80 }
81 
82 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
83  OpBuilder &builder) {
84  assert(isSupportedElementType(slot.elemType));
85  // TODO: support more types.
87  .Case([&](MemRefType t) {
88  return memref::AllocaOp::create(builder, getLoc(), t);
89  })
90  .Default([&](Type t) {
91  return arith::ConstantOp::create(builder, getLoc(), t,
92  builder.getZeroAttr(t));
93  });
94 }
95 
96 std::optional<PromotableAllocationOpInterface>
97 memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
98  Value defaultValue,
99  OpBuilder &builder) {
100  if (defaultValue.use_empty())
101  defaultValue.getDefiningOp()->erase();
102  this->erase();
103  return std::nullopt;
104 }
105 
106 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
107  BlockArgument argument,
108  OpBuilder &builder) {}
109 
111 memref::AllocaOp::getDestructurableSlots() {
112  MemRefType memrefType = getType();
113  auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
114  if (!destructurable)
115  return {};
116 
117  std::optional<DenseMap<Attribute, Type>> destructuredType =
118  destructurable.getSubelementIndexMap();
119  if (!destructuredType)
120  return {};
121 
122  return {
123  DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
124 }
125 
126 DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
127  const DestructurableMemorySlot &slot,
128  const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
130  builder.setInsertionPointAfter(*this);
131 
133 
134  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
135  for (Attribute usedIndex : usedIndices) {
136  Type elemType = memrefType.getTypeAtIndex(usedIndex);
137  MemRefType elemPtr = MemRefType::get({}, elemType);
138  auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr);
139  newAllocators.push_back(subAlloca);
140  slotMap.try_emplace<MemorySlot>(usedIndex,
141  {subAlloca.getResult(), elemType});
142  }
143 
144  return slotMap;
145 }
146 
147 std::optional<DestructurableAllocationOpInterface>
148 memref::AllocaOp::handleDestructuringComplete(
149  const DestructurableMemorySlot &slot, OpBuilder &builder) {
150  assert(slot.ptr == getResult());
151  this->erase();
152  return std::nullopt;
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // Interfaces for LoadOp/StoreOp
157 //===----------------------------------------------------------------------===//
158 
159 bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
160  return getMemRef() == slot.ptr;
161 }
162 
163 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
164 
165 Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
166  Value reachingDef,
167  const DataLayout &dataLayout) {
168  llvm_unreachable("getStored should not be called on LoadOp");
169 }
170 
171 bool memref::LoadOp::canUsesBeRemoved(
172  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
173  SmallVectorImpl<OpOperand *> &newBlockingUses,
174  const DataLayout &dataLayout) {
175  if (blockingUses.size() != 1)
176  return false;
177  Value blockingUse = (*blockingUses.begin())->get();
178  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
179  getResult().getType() == slot.elemType;
180 }
181 
182 DeletionKind memref::LoadOp::removeBlockingUses(
183  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
184  OpBuilder &builder, Value reachingDefinition,
185  const DataLayout &dataLayout) {
186  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
187  // pointer.
188  getResult().replaceAllUsesWith(reachingDefinition);
189  return DeletionKind::Delete;
190 }
191 
192 /// Returns the index of a memref in attribute form, given its indices. Returns
193 /// a null pointer if whether the indices form a valid index for the provided
194 /// MemRefType cannot be computed. The indices must come from a valid memref
195 /// StoreOp or LoadOp.
197  ValueRange indices,
198  MemRefType memrefType) {
200  for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
201  IntegerAttr coordAttr;
202  if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
203  return {};
204  // MemRefType shape dimensions are always positive (checked by verifier).
205  std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
206  if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
207  return {};
208  index.push_back(coordAttr);
209  }
210  return ArrayAttr::get(ctx, index);
211 }
212 
213 bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
214  SmallPtrSetImpl<Attribute> &usedIndices,
215  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
216  const DataLayout &dataLayout) {
217  if (slot.ptr != getMemRef())
218  return false;
221  if (!index)
222  return false;
223  usedIndices.insert(index);
224  return true;
225 }
226 
227 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
229  OpBuilder &builder,
230  const DataLayout &dataLayout) {
233  const MemorySlot &memorySlot = subslots.at(index);
234  setMemRef(memorySlot.ptr);
235  getIndicesMutable().clear();
236  return DeletionKind::Keep;
237 }
238 
239 bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
240 
241 bool memref::StoreOp::storesTo(const MemorySlot &slot) {
242  return getMemRef() == slot.ptr;
243 }
244 
245 Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
246  Value reachingDef,
247  const DataLayout &dataLayout) {
248  return getValue();
249 }
250 
251 bool memref::StoreOp::canUsesBeRemoved(
252  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
253  SmallVectorImpl<OpOperand *> &newBlockingUses,
254  const DataLayout &dataLayout) {
255  if (blockingUses.size() != 1)
256  return false;
257  Value blockingUse = (*blockingUses.begin())->get();
258  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
259  getValue() != slot.ptr && getValue().getType() == slot.elemType;
260 }
261 
262 DeletionKind memref::StoreOp::removeBlockingUses(
263  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
264  OpBuilder &builder, Value reachingDefinition,
265  const DataLayout &dataLayout) {
266  return DeletionKind::Delete;
267 }
268 
269 bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
270  SmallPtrSetImpl<Attribute> &usedIndices,
271  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
272  const DataLayout &dataLayout) {
273  if (slot.ptr != getMemRef() || getValue() == slot.ptr)
274  return false;
277  if (!index || !slot.subelementTypes.contains(index))
278  return false;
279  usedIndices.insert(index);
280  return true;
281 }
282 
283 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
285  OpBuilder &builder,
286  const DataLayout &dataLayout) {
289  const MemorySlot &memorySlot = subslots.at(index);
290  setMemRef(memorySlot.ptr);
291  getIndicesMutable().clear();
292  return DeletionKind::Keep;
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Interfaces for destructurable types
297 //===----------------------------------------------------------------------===//
298 
299 namespace {
300 
301 struct MemRefDestructurableTypeExternalModel
302  : public DestructurableTypeInterface::ExternalModel<
303  MemRefDestructurableTypeExternalModel, MemRefType> {
304  std::optional<DenseMap<Attribute, Type>>
305  getSubelementIndexMap(Type type) const {
306  auto memrefType = llvm::cast<MemRefType>(type);
307  constexpr int64_t maxMemrefSizeForDestructuring = 16;
308  if (!memrefType.hasStaticShape() ||
309  memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
310  memrefType.getNumElements() == 1)
311  return {};
312 
313  DenseMap<Attribute, Type> destructured;
315  memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
316  destructured.insert({index, memrefType.getElementType()});
317  });
318 
319  return destructured;
320  }
321 
322  Type getTypeAtIndex(Type type, Attribute index) const {
323  auto memrefType = llvm::cast<MemRefType>(type);
324  auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
325  if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
326  return {};
327 
328  Type indexType = IndexType::get(memrefType.getContext());
329  for (const auto &[coordAttr, dimSize] :
330  llvm::zip(coordArrAttr, memrefType.getShape())) {
331  auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
332  if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
333  coord.getInt() >= dimSize)
334  return {};
335  }
336 
337  return memrefType.getElementType();
338  }
339 };
340 
341 } // namespace
342 
343 //===----------------------------------------------------------------------===//
344 // Register external models
345 //===----------------------------------------------------------------------===//
346 
348  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
349  MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
350  });
351 }
static MLIRContext * getContext(OpFoldResult val)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the next index.
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribute of the coordinates.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > subelementTypes
Maps an index within the memory slot to the corresponding subelement type.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.