22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/ErrorHandling.h"
37 for (
size_t i = 0; i <
shape.size(); ++i) {
48template <
typename CallableT>
51 Type indexType = IndexType::get(ctx);
56 indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
57 walker(ArrayAttr::get(ctx, indexAsAttr));
67 if (!type.hasStaticShape())
70 if (type.getNumElements() != 1)
73 return {
MemorySlot{getResult(), type.getElementType()}};
78 return ub::PoisonOp::create(builder, getLoc(), slot.
elemType);
81std::optional<PromotableAllocationOpInterface>
82memref::AllocaOp::handlePromotionComplete(
const MemorySlot &slot,
91void memref::AllocaOp::handleBlockArgument(
const MemorySlot &slot,
96memref::AllocaOp::getDestructurableSlots() {
97 MemRefType memrefType =
getType();
98 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
102 std::optional<DenseMap<Attribute, Type>> destructuredType =
103 destructurable.getSubelementIndexMap();
104 if (!destructuredType)
119 auto memrefType = llvm::cast<DestructurableTypeInterface>(
getType());
120 for (
Attribute usedIndex : usedIndices) {
121 Type elemType = memrefType.getTypeAtIndex(usedIndex);
122 MemRefType elemPtr = MemRefType::get({}, elemType);
123 auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr);
124 newAllocators.push_back(subAlloca);
126 {subAlloca.getResult(), elemType});
132std::optional<DestructurableAllocationOpInterface>
133memref::AllocaOp::handleDestructuringComplete(
135 assert(slot.
ptr == getResult());
144bool memref::LoadOp::loadsFrom(
const MemorySlot &slot) {
148bool memref::LoadOp::storesTo(
const MemorySlot &slot) {
return false; }
153 llvm_unreachable(
"getStored should not be called on LoadOp");
156bool memref::LoadOp::canUsesBeRemoved(
160 if (blockingUses.size() != 1)
162 Value blockingUse = (*blockingUses.begin())->get();
173 getResult().replaceAllUsesWith(reachingDefinition);
183 MemRefType memrefType) {
185 for (
auto [coord, dimSize] : llvm::zip(
indices, memrefType.getShape())) {
186 IntegerAttr coordAttr;
190 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
191 if (!coordInt || coordInt.value() >=
static_cast<uint64_t
>(dimSize))
193 index.push_back(coordAttr);
195 return ArrayAttr::get(ctx,
index);
208 usedIndices.insert(
index);
219 setMemRef(memorySlot.
ptr);
220 getIndicesMutable().clear();
224bool memref::StoreOp::loadsFrom(
const MemorySlot &slot) {
return false; }
226bool memref::StoreOp::storesTo(
const MemorySlot &slot) {
236bool memref::StoreOp::canUsesBeRemoved(
240 if (blockingUses.size() != 1)
242 Value blockingUse = (*blockingUses.begin())->get();
264 usedIndices.insert(
index);
275 setMemRef(memorySlot.
ptr);
276 getIndicesMutable().clear();
286struct MemRefDestructurableTypeExternalModel
287 :
public DestructurableTypeInterface::ExternalModel<
288 MemRefDestructurableTypeExternalModel, MemRefType> {
289 std::optional<DenseMap<Attribute, Type>>
290 getSubelementIndexMap(Type type)
const {
291 auto memrefType = llvm::cast<MemRefType>(type);
292 constexpr int64_t maxMemrefSizeForDestructuring = 16;
293 if (!memrefType.hasStaticShape() ||
294 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
295 memrefType.getNumElements() == 1)
300 memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
301 destructured.insert({index, memrefType.getElementType()});
308 auto memrefType = llvm::cast<MemRefType>(type);
309 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(
index);
310 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
313 Type indexType = IndexType::get(memrefType.getContext());
314 for (
const auto &[coordAttr, dimSize] :
315 llvm::zip(coordArrAttr, memrefType.getShape())) {
316 auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
317 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
318 coord.getInt() >= dimSize)
322 return memrefType.getElementType();
334 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the next index; returns failure if index was already the last one.
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribute of its coordinates.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
void erase()
Remove this operation from its parent block and delete it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > subelementTypes
Maps an index within the memory slot to the corresponding subelement type.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.