MLIR  19.0.0git
MemRefMemorySlot.cpp
Go to the documentation of this file.
1 //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Mem2Reg-related interfaces for MemRef dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinDialect.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/Value.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/ErrorHandling.h"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utilities
32 //===----------------------------------------------------------------------===//
33 
34 /// Walks over the indices of the elements of a tensor of a given `shape` by
35 /// updating `index` in place to the next index. This returns failure if the
36 /// provided index was the last index.
39  for (size_t i = 0; i < shape.size(); ++i) {
40  index[i]++;
41  if (index[i] < shape[i])
42  return success();
43  index[i] = 0;
44  }
45  return failure();
46 }
47 
48 /// Calls `walker` for each index within a tensor of a given `shape`, providing
49 /// the index as an array attribute of the coordinates.
50 template <typename CallableT>
52  CallableT &&walker) {
53  Type indexType = IndexType::get(ctx);
54  SmallVector<int64_t> shapeIter(shape.size(), 0);
55  do {
56  SmallVector<Attribute> indexAsAttr;
57  for (int64_t dim : shapeIter)
58  indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
59  walker(ArrayAttr::get(ctx, indexAsAttr));
60  } while (succeeded(nextIndex(shape, shapeIter)));
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Interfaces for AllocaOp
65 //===----------------------------------------------------------------------===//
66 
67 static bool isSupportedElementType(Type type) {
68  return llvm::isa<MemRefType>(type) ||
69  OpBuilder(type.getContext()).getZeroAttr(type);
70 }
71 
72 SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
73  MemRefType type = getType();
74  if (!isSupportedElementType(type.getElementType()))
75  return {};
76  if (!type.hasStaticShape())
77  return {};
78  // Make sure the memref contains only a single element.
79  if (type.getNumElements() != 1)
80  return {};
81 
82  return {MemorySlot{getResult(), type.getElementType()}};
83 }
84 
85 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
86  OpBuilder &builder) {
87  assert(isSupportedElementType(slot.elemType));
88  // TODO: support more types.
90  .Case([&](MemRefType t) {
91  return builder.create<memref::AllocaOp>(getLoc(), t);
92  })
93  .Default([&](Type t) {
94  return builder.create<arith::ConstantOp>(getLoc(), t,
95  builder.getZeroAttr(t));
96  });
97 }
98 
99 std::optional<PromotableAllocationOpInterface>
100 memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
101  Value defaultValue,
102  OpBuilder &builder) {
103  if (defaultValue.use_empty())
104  defaultValue.getDefiningOp()->erase();
105  this->erase();
106  return std::nullopt;
107 }
108 
109 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
110  BlockArgument argument,
111  OpBuilder &builder) {}
112 
114 memref::AllocaOp::getDestructurableSlots() {
115  MemRefType memrefType = getType();
116  auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
117  if (!destructurable)
118  return {};
119 
120  std::optional<DenseMap<Attribute, Type>> destructuredType =
121  destructurable.getSubelementIndexMap();
122  if (!destructuredType)
123  return {};
124 
125  return {
126  DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
127 }
128 
129 DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
130  const DestructurableMemorySlot &slot,
131  const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
133  builder.setInsertionPointAfter(*this);
134 
136 
137  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
138  for (Attribute usedIndex : usedIndices) {
139  Type elemType = memrefType.getTypeAtIndex(usedIndex);
140  MemRefType elemPtr = MemRefType::get({}, elemType);
141  auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
142  newAllocators.push_back(subAlloca);
143  slotMap.try_emplace<MemorySlot>(usedIndex,
144  {subAlloca.getResult(), elemType});
145  }
146 
147  return slotMap;
148 }
149 
150 std::optional<DestructurableAllocationOpInterface>
151 memref::AllocaOp::handleDestructuringComplete(
152  const DestructurableMemorySlot &slot, OpBuilder &builder) {
153  assert(slot.ptr == getResult());
154  this->erase();
155  return std::nullopt;
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Interfaces for LoadOp/StoreOp
160 //===----------------------------------------------------------------------===//
161 
162 bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
163  return getMemRef() == slot.ptr;
164 }
165 
166 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
167 
168 Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
169  Value reachingDef,
170  const DataLayout &dataLayout) {
171  llvm_unreachable("getStored should not be called on LoadOp");
172 }
173 
174 bool memref::LoadOp::canUsesBeRemoved(
175  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
176  SmallVectorImpl<OpOperand *> &newBlockingUses,
177  const DataLayout &dataLayout) {
178  if (blockingUses.size() != 1)
179  return false;
180  Value blockingUse = (*blockingUses.begin())->get();
181  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
182  getResult().getType() == slot.elemType;
183 }
184 
185 DeletionKind memref::LoadOp::removeBlockingUses(
186  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
187  OpBuilder &builder, Value reachingDefinition,
188  const DataLayout &dataLayout) {
189  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
190  // pointer.
191  getResult().replaceAllUsesWith(reachingDefinition);
192  return DeletionKind::Delete;
193 }
194 
195 /// Returns the index of a memref in attribute form, given its indices. Returns
196 /// a null pointer if whether the indices form a valid index for the provided
197 /// MemRefType cannot be computed. The indices must come from a valid memref
198 /// StoreOp or LoadOp.
200  ValueRange indices,
201  MemRefType memrefType) {
203  for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
204  IntegerAttr coordAttr;
205  if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
206  return {};
207  // MemRefType shape dimensions are always positive (checked by verifier).
208  std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
209  if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
210  return {};
211  index.push_back(coordAttr);
212  }
213  return ArrayAttr::get(ctx, index);
214 }
215 
216 bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
217  SmallPtrSetImpl<Attribute> &usedIndices,
218  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
219  const DataLayout &dataLayout) {
220  if (slot.ptr != getMemRef())
221  return false;
224  if (!index)
225  return false;
226  usedIndices.insert(index);
227  return true;
228 }
229 
230 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
232  OpBuilder &builder,
233  const DataLayout &dataLayout) {
236  const MemorySlot &memorySlot = subslots.at(index);
237  setMemRef(memorySlot.ptr);
238  getIndicesMutable().clear();
239  return DeletionKind::Keep;
240 }
241 
242 bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
243 
244 bool memref::StoreOp::storesTo(const MemorySlot &slot) {
245  return getMemRef() == slot.ptr;
246 }
247 
248 Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
249  Value reachingDef,
250  const DataLayout &dataLayout) {
251  return getValue();
252 }
253 
254 bool memref::StoreOp::canUsesBeRemoved(
255  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
256  SmallVectorImpl<OpOperand *> &newBlockingUses,
257  const DataLayout &dataLayout) {
258  if (blockingUses.size() != 1)
259  return false;
260  Value blockingUse = (*blockingUses.begin())->get();
261  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
262  getValue() != slot.ptr && getValue().getType() == slot.elemType;
263 }
264 
265 DeletionKind memref::StoreOp::removeBlockingUses(
266  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
267  OpBuilder &builder, Value reachingDefinition,
268  const DataLayout &dataLayout) {
269  return DeletionKind::Delete;
270 }
271 
272 bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
273  SmallPtrSetImpl<Attribute> &usedIndices,
274  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
275  const DataLayout &dataLayout) {
276  if (slot.ptr != getMemRef() || getValue() == slot.ptr)
277  return false;
280  if (!index || !slot.elementPtrs.contains(index))
281  return false;
282  usedIndices.insert(index);
283  return true;
284 }
285 
286 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
288  OpBuilder &builder,
289  const DataLayout &dataLayout) {
292  const MemorySlot &memorySlot = subslots.at(index);
293  setMemRef(memorySlot.ptr);
294  getIndicesMutable().clear();
295  return DeletionKind::Keep;
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // Interfaces for destructurable types
300 //===----------------------------------------------------------------------===//
301 
302 namespace {
303 
304 struct MemRefDestructurableTypeExternalModel
305  : public DestructurableTypeInterface::ExternalModel<
306  MemRefDestructurableTypeExternalModel, MemRefType> {
307  std::optional<DenseMap<Attribute, Type>>
308  getSubelementIndexMap(Type type) const {
309  auto memrefType = llvm::cast<MemRefType>(type);
310  constexpr int64_t maxMemrefSizeForDestructuring = 16;
311  if (!memrefType.hasStaticShape() ||
312  memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
313  memrefType.getNumElements() == 1)
314  return {};
315 
316  DenseMap<Attribute, Type> destructured;
318  memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
319  destructured.insert({index, memrefType.getElementType()});
320  });
321 
322  return destructured;
323  }
324 
325  Type getTypeAtIndex(Type type, Attribute index) const {
326  auto memrefType = llvm::cast<MemRefType>(type);
327  auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
328  if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
329  return {};
330 
331  Type indexType = IndexType::get(memrefType.getContext());
332  for (const auto &[coordAttr, dimSize] :
333  llvm::zip(coordArrAttr, memrefType.getShape())) {
334  auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
335  if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
336  coord.getInt() >= dimSize)
337  return {};
338  }
339 
340  return memrefType.getElementType();
341  }
342 };
343 
344 } // namespace
345 
346 //===----------------------------------------------------------------------===//
347 // Register external models
348 //===----------------------------------------------------------------------===//
349 
351  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
352  MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
353  });
354 }
static MLIRContext * getContext(OpFoldResult val)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribu...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the corresponding subelement type.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.