MLIR  19.0.0git
MemRefMemorySlot.cpp
Go to the documentation of this file.
1 //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Mem2Reg-related interfaces for MemRef dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinDialect.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/Value.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/ErrorHandling.h"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utilities
32 //===----------------------------------------------------------------------===//
33 
34 /// Walks over the indices of the elements of a tensor of a given `shape` by
35 /// updating `index` in place to the next index. This returns failure if the
36 /// provided index was the last index.
39  for (size_t i = 0; i < shape.size(); ++i) {
40  index[i]++;
41  if (index[i] < shape[i])
42  return success();
43  index[i] = 0;
44  }
45  return failure();
46 }
47 
48 /// Calls `walker` for each index within a tensor of a given `shape`, providing
49 /// the index as an array attribute of the coordinates.
50 template <typename CallableT>
52  CallableT &&walker) {
53  Type indexType = IndexType::get(ctx);
54  SmallVector<int64_t> shapeIter(shape.size(), 0);
55  do {
56  SmallVector<Attribute> indexAsAttr;
57  for (int64_t dim : shapeIter)
58  indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
59  walker(ArrayAttr::get(ctx, indexAsAttr));
60  } while (succeeded(nextIndex(shape, shapeIter)));
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Interfaces for AllocaOp
65 //===----------------------------------------------------------------------===//
66 
67 static bool isSupportedElementType(Type type) {
68  return llvm::isa<MemRefType>(type) ||
69  OpBuilder(type.getContext()).getZeroAttr(type);
70 }
71 
72 SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
73  MemRefType type = getType();
74  if (!isSupportedElementType(type.getElementType()))
75  return {};
76  if (!type.hasStaticShape())
77  return {};
78  // Make sure the memref contains only a single element.
79  if (type.getNumElements() != 1)
80  return {};
81 
82  return {MemorySlot{getResult(), type.getElementType()}};
83 }
84 
85 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
86  RewriterBase &rewriter) {
87  assert(isSupportedElementType(slot.elemType));
88  // TODO: support more types.
90  .Case([&](MemRefType t) {
91  return rewriter.create<memref::AllocaOp>(getLoc(), t);
92  })
93  .Default([&](Type t) {
94  return rewriter.create<arith::ConstantOp>(getLoc(), t,
95  rewriter.getZeroAttr(t));
96  });
97 }
98 
99 void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
100  Value defaultValue,
101  RewriterBase &rewriter) {
102  if (defaultValue.use_empty())
103  rewriter.eraseOp(defaultValue.getDefiningOp());
104  rewriter.eraseOp(*this);
105 }
106 
107 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
108  BlockArgument argument,
109  RewriterBase &rewriter) {}
110 
112 memref::AllocaOp::getDestructurableSlots() {
113  MemRefType memrefType = getType();
114  auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
115  if (!destructurable)
116  return {};
117 
118  std::optional<DenseMap<Attribute, Type>> destructuredType =
119  destructurable.getSubelementIndexMap();
120  if (!destructuredType)
121  return {};
122 
123  return {
124  DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
125 }
126 
128 memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
129  const SmallPtrSetImpl<Attribute> &usedIndices,
130  RewriterBase &rewriter) {
131  rewriter.setInsertionPointAfter(*this);
132 
134 
135  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
136  for (Attribute usedIndex : usedIndices) {
137  Type elemType = memrefType.getTypeAtIndex(usedIndex);
138  MemRefType elemPtr = MemRefType::get({}, elemType);
139  auto subAlloca = rewriter.create<memref::AllocaOp>(getLoc(), elemPtr);
140  slotMap.try_emplace<MemorySlot>(usedIndex,
141  {subAlloca.getResult(), elemType});
142  }
143 
144  return slotMap;
145 }
146 
147 void memref::AllocaOp::handleDestructuringComplete(
148  const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
149  assert(slot.ptr == getResult());
150  rewriter.eraseOp(*this);
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // Interfaces for LoadOp/StoreOp
155 //===----------------------------------------------------------------------===//
156 
157 bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
158  return getMemRef() == slot.ptr;
159 }
160 
161 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
162 
163 Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
164  Value reachingDef,
165  const DataLayout &dataLayout) {
166  llvm_unreachable("getStored should not be called on LoadOp");
167 }
168 
169 bool memref::LoadOp::canUsesBeRemoved(
170  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
171  SmallVectorImpl<OpOperand *> &newBlockingUses,
172  const DataLayout &dataLayout) {
173  if (blockingUses.size() != 1)
174  return false;
175  Value blockingUse = (*blockingUses.begin())->get();
176  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
177  getResult().getType() == slot.elemType;
178 }
179 
180 DeletionKind memref::LoadOp::removeBlockingUses(
181  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
182  RewriterBase &rewriter, Value reachingDefinition,
183  const DataLayout &dataLayout) {
184  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
185  // pointer.
186  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
187  return DeletionKind::Delete;
188 }
189 
190 /// Returns the index of a memref in attribute form, given its indices. Returns
191 /// a null pointer if whether the indices form a valid index for the provided
192 /// MemRefType cannot be computed. The indices must come from a valid memref
193 /// StoreOp or LoadOp.
195  ValueRange indices,
196  MemRefType memrefType) {
198  for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
199  IntegerAttr coordAttr;
200  if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
201  return {};
202  // MemRefType shape dimensions are always positive (checked by verifier).
203  std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
204  if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
205  return {};
206  index.push_back(coordAttr);
207  }
208  return ArrayAttr::get(ctx, index);
209 }
210 
211 bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
212  SmallPtrSetImpl<Attribute> &usedIndices,
213  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
214  const DataLayout &dataLayout) {
215  if (slot.ptr != getMemRef())
216  return false;
219  if (!index)
220  return false;
221  usedIndices.insert(index);
222  return true;
223 }
224 
225 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
227  RewriterBase &rewriter,
228  const DataLayout &dataLayout) {
231  const MemorySlot &memorySlot = subslots.at(index);
232  rewriter.modifyOpInPlace(*this, [&]() {
233  setMemRef(memorySlot.ptr);
234  getIndicesMutable().clear();
235  });
236  return DeletionKind::Keep;
237 }
238 
239 bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
240 
241 bool memref::StoreOp::storesTo(const MemorySlot &slot) {
242  return getMemRef() == slot.ptr;
243 }
244 
245 Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
246  Value reachingDef,
247  const DataLayout &dataLayout) {
248  return getValue();
249 }
250 
251 bool memref::StoreOp::canUsesBeRemoved(
252  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
253  SmallVectorImpl<OpOperand *> &newBlockingUses,
254  const DataLayout &dataLayout) {
255  if (blockingUses.size() != 1)
256  return false;
257  Value blockingUse = (*blockingUses.begin())->get();
258  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
259  getValue() != slot.ptr && getValue().getType() == slot.elemType;
260 }
261 
262 DeletionKind memref::StoreOp::removeBlockingUses(
263  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
264  RewriterBase &rewriter, Value reachingDefinition,
265  const DataLayout &dataLayout) {
266  return DeletionKind::Delete;
267 }
268 
269 bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
270  SmallPtrSetImpl<Attribute> &usedIndices,
271  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
272  const DataLayout &dataLayout) {
273  if (slot.ptr != getMemRef() || getValue() == slot.ptr)
274  return false;
277  if (!index || !slot.elementPtrs.contains(index))
278  return false;
279  usedIndices.insert(index);
280  return true;
281 }
282 
283 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
285  RewriterBase &rewriter,
286  const DataLayout &dataLayout) {
289  const MemorySlot &memorySlot = subslots.at(index);
290  rewriter.modifyOpInPlace(*this, [&]() {
291  setMemRef(memorySlot.ptr);
292  getIndicesMutable().clear();
293  });
294  return DeletionKind::Keep;
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // Interfaces for destructurable types
299 //===----------------------------------------------------------------------===//
300 
301 namespace {
302 
303 struct MemRefDestructurableTypeExternalModel
304  : public DestructurableTypeInterface::ExternalModel<
305  MemRefDestructurableTypeExternalModel, MemRefType> {
306  std::optional<DenseMap<Attribute, Type>>
307  getSubelementIndexMap(Type type) const {
308  auto memrefType = llvm::cast<MemRefType>(type);
309  constexpr int64_t maxMemrefSizeForDestructuring = 16;
310  if (!memrefType.hasStaticShape() ||
311  memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
312  memrefType.getNumElements() == 1)
313  return {};
314 
315  DenseMap<Attribute, Type> destructured;
317  memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
318  destructured.insert({index, memrefType.getElementType()});
319  });
320 
321  return destructured;
322  }
323 
324  Type getTypeAtIndex(Type type, Attribute index) const {
325  auto memrefType = llvm::cast<MemRefType>(type);
326  auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
327  if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
328  return {};
329 
330  Type indexType = IndexType::get(memrefType.getContext());
331  for (const auto &[coordAttr, dimSize] :
332  llvm::zip(coordArrAttr, memrefType.getShape())) {
333  auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
334  if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
335  coord.getInt() >= dimSize)
336  return {};
337  }
338 
339  return memrefType.getElementType();
340  }
341 };
342 
343 } // namespace
344 
345 //===----------------------------------------------------------------------===//
346 // Register external models
347 //===----------------------------------------------------------------------===//
348 
350  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
351  MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
352  });
353 }
static MLIRContext * getContext(OpFoldResult val)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the next index.
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribute of the coordinates.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the corresponding subelement type.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.