MLIR  20.0.0git
MemRefMemorySlot.cpp
Go to the documentation of this file.
1 //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Mem2Reg-related interfaces for MemRef dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinDialect.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/Value.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/ErrorHandling.h"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Utilities
31 //===----------------------------------------------------------------------===//
32 
33 /// Walks over the indices of the elements of a tensor of a given `shape` by
34 /// updating `index` in place to the next index. This returns failure if the
35 /// provided index was the last index.
36 static LogicalResult nextIndex(ArrayRef<int64_t> shape,
38  for (size_t i = 0; i < shape.size(); ++i) {
39  index[i]++;
40  if (index[i] < shape[i])
41  return success();
42  index[i] = 0;
43  }
44  return failure();
45 }
46 
47 /// Calls `walker` for each index within a tensor of a given `shape`, providing
48 /// the index as an array attribute of the coordinates.
49 template <typename CallableT>
51  CallableT &&walker) {
52  Type indexType = IndexType::get(ctx);
53  SmallVector<int64_t> shapeIter(shape.size(), 0);
54  do {
55  SmallVector<Attribute> indexAsAttr;
56  for (int64_t dim : shapeIter)
57  indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
58  walker(ArrayAttr::get(ctx, indexAsAttr));
59  } while (succeeded(nextIndex(shape, shapeIter)));
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Interfaces for AllocaOp
64 //===----------------------------------------------------------------------===//
65 
66 static bool isSupportedElementType(Type type) {
67  return llvm::isa<MemRefType>(type) ||
68  OpBuilder(type.getContext()).getZeroAttr(type);
69 }
70 
71 SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
72  MemRefType type = getType();
73  if (!isSupportedElementType(type.getElementType()))
74  return {};
75  if (!type.hasStaticShape())
76  return {};
77  // Make sure the memref contains only a single element.
78  if (type.getNumElements() != 1)
79  return {};
80 
81  return {MemorySlot{getResult(), type.getElementType()}};
82 }
83 
84 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
85  OpBuilder &builder) {
86  assert(isSupportedElementType(slot.elemType));
87  // TODO: support more types.
89  .Case([&](MemRefType t) {
90  return builder.create<memref::AllocaOp>(getLoc(), t);
91  })
92  .Default([&](Type t) {
93  return builder.create<arith::ConstantOp>(getLoc(), t,
94  builder.getZeroAttr(t));
95  });
96 }
97 
98 std::optional<PromotableAllocationOpInterface>
99 memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
100  Value defaultValue,
101  OpBuilder &builder) {
102  if (defaultValue.use_empty())
103  defaultValue.getDefiningOp()->erase();
104  this->erase();
105  return std::nullopt;
106 }
107 
108 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
109  BlockArgument argument,
110  OpBuilder &builder) {}
111 
113 memref::AllocaOp::getDestructurableSlots() {
114  MemRefType memrefType = getType();
115  auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
116  if (!destructurable)
117  return {};
118 
119  std::optional<DenseMap<Attribute, Type>> destructuredType =
120  destructurable.getSubelementIndexMap();
121  if (!destructuredType)
122  return {};
123 
124  return {
125  DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
126 }
127 
128 DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
129  const DestructurableMemorySlot &slot,
130  const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
132  builder.setInsertionPointAfter(*this);
133 
135 
136  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
137  for (Attribute usedIndex : usedIndices) {
138  Type elemType = memrefType.getTypeAtIndex(usedIndex);
139  MemRefType elemPtr = MemRefType::get({}, elemType);
140  auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
141  newAllocators.push_back(subAlloca);
142  slotMap.try_emplace<MemorySlot>(usedIndex,
143  {subAlloca.getResult(), elemType});
144  }
145 
146  return slotMap;
147 }
148 
149 std::optional<DestructurableAllocationOpInterface>
150 memref::AllocaOp::handleDestructuringComplete(
151  const DestructurableMemorySlot &slot, OpBuilder &builder) {
152  assert(slot.ptr == getResult());
153  this->erase();
154  return std::nullopt;
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // Interfaces for LoadOp/StoreOp
159 //===----------------------------------------------------------------------===//
160 
161 bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
162  return getMemRef() == slot.ptr;
163 }
164 
165 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
166 
167 Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
168  Value reachingDef,
169  const DataLayout &dataLayout) {
170  llvm_unreachable("getStored should not be called on LoadOp");
171 }
172 
173 bool memref::LoadOp::canUsesBeRemoved(
174  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
175  SmallVectorImpl<OpOperand *> &newBlockingUses,
176  const DataLayout &dataLayout) {
177  if (blockingUses.size() != 1)
178  return false;
179  Value blockingUse = (*blockingUses.begin())->get();
180  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
181  getResult().getType() == slot.elemType;
182 }
183 
184 DeletionKind memref::LoadOp::removeBlockingUses(
185  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
186  OpBuilder &builder, Value reachingDefinition,
187  const DataLayout &dataLayout) {
188  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
189  // pointer.
190  getResult().replaceAllUsesWith(reachingDefinition);
191  return DeletionKind::Delete;
192 }
193 
194 /// Returns the index of a memref in attribute form, given its indices. Returns
195 /// a null pointer if whether the indices form a valid index for the provided
196 /// MemRefType cannot be computed. The indices must come from a valid memref
197 /// StoreOp or LoadOp.
199  ValueRange indices,
200  MemRefType memrefType) {
202  for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
203  IntegerAttr coordAttr;
204  if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
205  return {};
206  // MemRefType shape dimensions are always positive (checked by verifier).
207  std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
208  if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
209  return {};
210  index.push_back(coordAttr);
211  }
212  return ArrayAttr::get(ctx, index);
213 }
214 
215 bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
216  SmallPtrSetImpl<Attribute> &usedIndices,
217  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
218  const DataLayout &dataLayout) {
219  if (slot.ptr != getMemRef())
220  return false;
223  if (!index)
224  return false;
225  usedIndices.insert(index);
226  return true;
227 }
228 
229 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
231  OpBuilder &builder,
232  const DataLayout &dataLayout) {
235  const MemorySlot &memorySlot = subslots.at(index);
236  setMemRef(memorySlot.ptr);
237  getIndicesMutable().clear();
238  return DeletionKind::Keep;
239 }
240 
241 bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
242 
243 bool memref::StoreOp::storesTo(const MemorySlot &slot) {
244  return getMemRef() == slot.ptr;
245 }
246 
247 Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
248  Value reachingDef,
249  const DataLayout &dataLayout) {
250  return getValue();
251 }
252 
253 bool memref::StoreOp::canUsesBeRemoved(
254  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
255  SmallVectorImpl<OpOperand *> &newBlockingUses,
256  const DataLayout &dataLayout) {
257  if (blockingUses.size() != 1)
258  return false;
259  Value blockingUse = (*blockingUses.begin())->get();
260  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
261  getValue() != slot.ptr && getValue().getType() == slot.elemType;
262 }
263 
264 DeletionKind memref::StoreOp::removeBlockingUses(
265  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
266  OpBuilder &builder, Value reachingDefinition,
267  const DataLayout &dataLayout) {
268  return DeletionKind::Delete;
269 }
270 
271 bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
272  SmallPtrSetImpl<Attribute> &usedIndices,
273  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
274  const DataLayout &dataLayout) {
275  if (slot.ptr != getMemRef() || getValue() == slot.ptr)
276  return false;
279  if (!index || !slot.subelementTypes.contains(index))
280  return false;
281  usedIndices.insert(index);
282  return true;
283 }
284 
285 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
287  OpBuilder &builder,
288  const DataLayout &dataLayout) {
291  const MemorySlot &memorySlot = subslots.at(index);
292  setMemRef(memorySlot.ptr);
293  getIndicesMutable().clear();
294  return DeletionKind::Keep;
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // Interfaces for destructurable types
299 //===----------------------------------------------------------------------===//
300 
301 namespace {
302 
303 struct MemRefDestructurableTypeExternalModel
304  : public DestructurableTypeInterface::ExternalModel<
305  MemRefDestructurableTypeExternalModel, MemRefType> {
306  std::optional<DenseMap<Attribute, Type>>
307  getSubelementIndexMap(Type type) const {
308  auto memrefType = llvm::cast<MemRefType>(type);
309  constexpr int64_t maxMemrefSizeForDestructuring = 16;
310  if (!memrefType.hasStaticShape() ||
311  memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
312  memrefType.getNumElements() == 1)
313  return {};
314 
315  DenseMap<Attribute, Type> destructured;
317  memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
318  destructured.insert({index, memrefType.getElementType()});
319  });
320 
321  return destructured;
322  }
323 
324  Type getTypeAtIndex(Type type, Attribute index) const {
325  auto memrefType = llvm::cast<MemRefType>(type);
326  auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
327  if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
328  return {};
329 
330  Type indexType = IndexType::get(memrefType.getContext());
331  for (const auto &[coordAttr, dimSize] :
332  llvm::zip(coordArrAttr, memrefType.getShape())) {
333  auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
334  if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
335  coord.getInt() >= dimSize)
336  return {};
337  }
338 
339  return memrefType.getElementType();
340  }
341 };
342 
343 } // namespace
344 
345 //===----------------------------------------------------------------------===//
346 // Register external models
347 //===----------------------------------------------------------------------===//
348 
350  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
351  MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
352  });
353 }
static MLIRContext * getContext(OpFoldResult val)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribu...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > subelementTypes
Maps an index within the memory slot to the corresponding subelement type.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.