21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/ErrorHandling.h"
36 for (
size_t i = 0; i < shape.size(); ++i) {
38 if (index[i] < shape[i])
47 template <
typename CallableT>
54 for (int64_t dim : shapeIter)
57 }
while (succeeded(
nextIndex(shape, shapeIter)));
65 return llvm::isa<MemRefType>(type) ||
73 if (!type.hasStaticShape())
76 if (type.getNumElements() != 1)
79 return {
MemorySlot{getResult(), type.getElementType()}};
87 .Case([&](MemRefType t) {
88 return memref::AllocaOp::create(builder, getLoc(), t);
90 .Default([&](
Type t) {
91 return arith::ConstantOp::create(builder, getLoc(), t,
96 std::optional<PromotableAllocationOpInterface>
97 memref::AllocaOp::handlePromotionComplete(
const MemorySlot &slot,
106 void memref::AllocaOp::handleBlockArgument(
const MemorySlot &slot,
111 memref::AllocaOp::getDestructurableSlots() {
112 MemRefType memrefType =
getType();
113 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
117 std::optional<DenseMap<Attribute, Type>> destructuredType =
118 destructurable.getSubelementIndexMap();
119 if (!destructuredType)
134 auto memrefType = llvm::cast<DestructurableTypeInterface>(
getType());
135 for (
Attribute usedIndex : usedIndices) {
136 Type elemType = memrefType.getTypeAtIndex(usedIndex);
138 auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr);
139 newAllocators.push_back(subAlloca);
141 {subAlloca.getResult(), elemType});
147 std::optional<DestructurableAllocationOpInterface>
148 memref::AllocaOp::handleDestructuringComplete(
150 assert(slot.
ptr == getResult());
159 bool memref::LoadOp::loadsFrom(
const MemorySlot &slot) {
160 return getMemRef() == slot.
ptr;
/// Reports whether this op stores to `slot`. The answer is unconditionally
/// false: the body never inspects `slot`.
163 bool memref::LoadOp::storesTo(
const MemorySlot &slot) {
return false; }
168 llvm_unreachable(
"getStored should not be called on LoadOp");
171 bool memref::LoadOp::canUsesBeRemoved(
175 if (blockingUses.size() != 1)
177 Value blockingUse = (*blockingUses.begin())->
get();
178 return blockingUse == slot.
ptr && getMemRef() == slot.
ptr &&
188 getResult().replaceAllUsesWith(reachingDefinition);
198 MemRefType memrefType) {
200 for (
auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
201 IntegerAttr coordAttr;
202 if (!
matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
205 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
206 if (!coordInt || coordInt.value() >=
static_cast<uint64_t
>(dimSize))
208 index.push_back(coordAttr);
217 if (slot.
ptr != getMemRef())
223 usedIndices.insert(index);
233 const MemorySlot &memorySlot = subslots.at(index);
234 setMemRef(memorySlot.
ptr);
235 getIndicesMutable().clear();
/// Reports whether this op loads from `slot`. The answer is unconditionally
/// false: the body never inspects `slot`.
239 bool memref::StoreOp::loadsFrom(
const MemorySlot &slot) {
return false; }
241 bool memref::StoreOp::storesTo(
const MemorySlot &slot) {
242 return getMemRef() == slot.
ptr;
251 bool memref::StoreOp::canUsesBeRemoved(
255 if (blockingUses.size() != 1)
257 Value blockingUse = (*blockingUses.begin())->
get();
258 return blockingUse == slot.
ptr && getMemRef() == slot.
ptr &&
273 if (slot.
ptr != getMemRef() || getValue() == slot.
ptr)
279 usedIndices.insert(index);
289 const MemorySlot &memorySlot = subslots.at(index);
290 setMemRef(memorySlot.
ptr);
291 getIndicesMutable().clear();
301 struct MemRefDestructurableTypeExternalModel
302 :
public DestructurableTypeInterface::ExternalModel<
303 MemRefDestructurableTypeExternalModel, MemRefType> {
304 std::optional<DenseMap<Attribute, Type>>
305 getSubelementIndexMap(
Type type)
const {
306 auto memrefType = llvm::cast<MemRefType>(type);
307 constexpr int64_t maxMemrefSizeForDestructuring = 16;
308 if (!memrefType.hasStaticShape() ||
309 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
310 memrefType.getNumElements() == 1)
315 memrefType.getContext(), memrefType.getShape(), [&](
Attribute index) {
316 destructured.insert({index, memrefType.getElementType()});
323 auto memrefType = llvm::cast<MemRefType>(type);
324 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
325 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
329 for (
const auto &[coordAttr, dimSize] :
330 llvm::zip(coordArrAttr, memrefType.getShape())) {
331 auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
332 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
333 coord.getInt() >= dimSize)
337 return memrefType.getElementType();
349 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
static MLIRContext * getContext(OpFoldResult val)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribu...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
TypedAttr getZeroAttr(Type type)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
void erase()
Remove this operation from its parent block and delete it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > subelementTypes
Maps an index within the memory slot to the corresponding subelement type.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.