MLIR 22.0.0git
MemRefMemorySlot.cpp
Go to the documentation of this file.
1//===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements Mem2Reg-related interfaces for MemRef dialect
10// operations.
11//
12//===----------------------------------------------------------------------===//
13
19#include "mlir/IR/Matchers.h"
20#include "mlir/IR/Value.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/ErrorHandling.h"
25
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// Utilities
30//===----------------------------------------------------------------------===//
31
32/// Walks over the indices of the elements of a tensor of a given `shape` by
33/// updating `index` in place to the next index. This returns failure if the
34/// provided index was the last index.
35static LogicalResult nextIndex(ArrayRef<int64_t> shape,
37 for (size_t i = 0; i < shape.size(); ++i) {
38 index[i]++;
39 if (index[i] < shape[i])
40 return success();
41 index[i] = 0;
42 }
43 return failure();
44}
45
46/// Calls `walker` for each index within a tensor of a given `shape`, providing
47/// the index as an array attribute of the coordinates.
48template <typename CallableT>
50 CallableT &&walker) {
51 Type indexType = IndexType::get(ctx);
52 SmallVector<int64_t> shapeIter(shape.size(), 0);
53 do {
54 SmallVector<Attribute> indexAsAttr;
55 for (int64_t dim : shapeIter)
56 indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
57 walker(ArrayAttr::get(ctx, indexAsAttr));
58 } while (succeeded(nextIndex(shape, shapeIter)));
59}
60
61//===----------------------------------------------------------------------===//
62// Interfaces for AllocaOp
63//===----------------------------------------------------------------------===//
64
65SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
66 MemRefType type = getType();
67 if (!type.hasStaticShape())
68 return {};
69 // Make sure the memref contains only a single element.
70 if (type.getNumElements() != 1)
71 return {};
72
73 return {MemorySlot{getResult(), type.getElementType()}};
74}
75
76Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
77 OpBuilder &builder) {
78 return ub::PoisonOp::create(builder, getLoc(), slot.elemType);
79}
80
81std::optional<PromotableAllocationOpInterface>
82memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
83 Value defaultValue,
84 OpBuilder &builder) {
85 if (defaultValue.use_empty())
86 defaultValue.getDefiningOp()->erase();
87 this->erase();
88 return std::nullopt;
89}
90
91void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
92 BlockArgument argument,
93 OpBuilder &builder) {}
94
96memref::AllocaOp::getDestructurableSlots() {
97 MemRefType memrefType = getType();
98 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
99 if (!destructurable)
100 return {};
101
102 std::optional<DenseMap<Attribute, Type>> destructuredType =
103 destructurable.getSubelementIndexMap();
104 if (!destructuredType)
105 return {};
106
107 return {
108 DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
109}
110
111DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
112 const DestructurableMemorySlot &slot,
113 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
115 builder.setInsertionPointAfter(*this);
116
118
119 auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
120 for (Attribute usedIndex : usedIndices) {
121 Type elemType = memrefType.getTypeAtIndex(usedIndex);
122 MemRefType elemPtr = MemRefType::get({}, elemType);
123 auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr);
124 newAllocators.push_back(subAlloca);
125 slotMap.try_emplace<MemorySlot>(usedIndex,
126 {subAlloca.getResult(), elemType});
127 }
128
129 return slotMap;
130}
131
132std::optional<DestructurableAllocationOpInterface>
133memref::AllocaOp::handleDestructuringComplete(
134 const DestructurableMemorySlot &slot, OpBuilder &builder) {
135 assert(slot.ptr == getResult());
136 this->erase();
137 return std::nullopt;
138}
139
140//===----------------------------------------------------------------------===//
141// Interfaces for LoadOp/StoreOp
142//===----------------------------------------------------------------------===//
143
144bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
145 return getMemRef() == slot.ptr;
146}
147
148bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
149
150Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
151 Value reachingDef,
152 const DataLayout &dataLayout) {
153 llvm_unreachable("getStored should not be called on LoadOp");
154}
155
156bool memref::LoadOp::canUsesBeRemoved(
157 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
158 SmallVectorImpl<OpOperand *> &newBlockingUses,
159 const DataLayout &dataLayout) {
160 if (blockingUses.size() != 1)
161 return false;
162 Value blockingUse = (*blockingUses.begin())->get();
163 return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
164 getResult().getType() == slot.elemType;
165}
166
167DeletionKind memref::LoadOp::removeBlockingUses(
168 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
169 OpBuilder &builder, Value reachingDefinition,
170 const DataLayout &dataLayout) {
171 // `canUsesBeRemoved` checked this blocking use must be the loaded slot
172 // pointer.
173 getResult().replaceAllUsesWith(reachingDefinition);
175}
176
177/// Returns the index of a memref in attribute form, given its indices. Returns
178/// a null pointer if whether the indices form a valid index for the provided
179/// MemRefType cannot be computed. The indices must come from a valid memref
180/// StoreOp or LoadOp.
183 MemRefType memrefType) {
185 for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
186 IntegerAttr coordAttr;
187 if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
188 return {};
189 // MemRefType shape dimensions are always positive (checked by verifier).
190 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
191 if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
192 return {};
193 index.push_back(coordAttr);
194 }
195 return ArrayAttr::get(ctx, index);
196}
197
198bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
199 SmallPtrSetImpl<Attribute> &usedIndices,
200 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
201 const DataLayout &dataLayout) {
202 if (slot.ptr != getMemRef())
203 return false;
206 if (!index)
207 return false;
208 usedIndices.insert(index);
209 return true;
210}
211
212DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
214 OpBuilder &builder,
215 const DataLayout &dataLayout) {
218 const MemorySlot &memorySlot = subslots.at(index);
219 setMemRef(memorySlot.ptr);
220 getIndicesMutable().clear();
221 return DeletionKind::Keep;
222}
223
224bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
225
226bool memref::StoreOp::storesTo(const MemorySlot &slot) {
227 return getMemRef() == slot.ptr;
228}
229
230Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
231 Value reachingDef,
232 const DataLayout &dataLayout) {
233 return getValue();
234}
235
236bool memref::StoreOp::canUsesBeRemoved(
237 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
238 SmallVectorImpl<OpOperand *> &newBlockingUses,
239 const DataLayout &dataLayout) {
240 if (blockingUses.size() != 1)
241 return false;
242 Value blockingUse = (*blockingUses.begin())->get();
243 return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
244 getValue() != slot.ptr && getValue().getType() == slot.elemType;
245}
246
247DeletionKind memref::StoreOp::removeBlockingUses(
248 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
249 OpBuilder &builder, Value reachingDefinition,
250 const DataLayout &dataLayout) {
252}
253
254bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
255 SmallPtrSetImpl<Attribute> &usedIndices,
256 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
257 const DataLayout &dataLayout) {
258 if (slot.ptr != getMemRef() || getValue() == slot.ptr)
259 return false;
262 if (!index || !slot.subelementTypes.contains(index))
263 return false;
264 usedIndices.insert(index);
265 return true;
266}
267
268DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
270 OpBuilder &builder,
271 const DataLayout &dataLayout) {
274 const MemorySlot &memorySlot = subslots.at(index);
275 setMemRef(memorySlot.ptr);
276 getIndicesMutable().clear();
277 return DeletionKind::Keep;
278}
279
280//===----------------------------------------------------------------------===//
281// Interfaces for destructurable types
282//===----------------------------------------------------------------------===//
283
284namespace {
285
286struct MemRefDestructurableTypeExternalModel
287 : public DestructurableTypeInterface::ExternalModel<
288 MemRefDestructurableTypeExternalModel, MemRefType> {
289 std::optional<DenseMap<Attribute, Type>>
290 getSubelementIndexMap(Type type) const {
291 auto memrefType = llvm::cast<MemRefType>(type);
292 constexpr int64_t maxMemrefSizeForDestructuring = 16;
293 if (!memrefType.hasStaticShape() ||
294 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
295 memrefType.getNumElements() == 1)
296 return {};
297
298 DenseMap<Attribute, Type> destructured;
300 memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
301 destructured.insert({index, memrefType.getElementType()});
302 });
303
304 return destructured;
305 }
306
307 Type getTypeAtIndex(Type type, Attribute index) const {
308 auto memrefType = llvm::cast<MemRefType>(type);
309 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
310 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
311 return {};
312
313 Type indexType = IndexType::get(memrefType.getContext());
314 for (const auto &[coordAttr, dimSize] :
315 llvm::zip(coordArrAttr, memrefType.getShape())) {
316 auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
317 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
318 coord.getInt() >= dimSize)
319 return {};
320 }
321
322 return memrefType.getElementType();
323 }
324};
325
326} // namespace
327
328//===----------------------------------------------------------------------===//
329// Register external models
330//===----------------------------------------------------------------------===//
331
333 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
334 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
335 });
336}
return success()
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition Utils.cpp:246
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
b getContext())
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribu...
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
void erase()
Remove this operation from its parent block and delete it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > subelementTypes
Maps an index within the memory slot to the corresponding subelement type.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.