23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/ErrorHandling.h"
38 for (
size_t i = 0; i < shape.size(); ++i) {
40 if (index[i] < shape[i])
49 template <
typename CallableT>
56 for (int64_t dim : shapeIter)
59 }
while (succeeded(
nextIndex(shape, shapeIter)));
67 return llvm::isa<MemRefType>(type) ||
75 if (!type.hasStaticShape())
78 if (type.getNumElements() != 1)
81 return {
MemorySlot{getResult(), type.getElementType()}};
89 .Case([&](MemRefType t) {
90 return builder.
create<memref::AllocaOp>(getLoc(), t);
92 .Default([&](
Type t) {
93 return builder.
create<arith::ConstantOp>(getLoc(), t,
98 std::optional<PromotableAllocationOpInterface>
99 memref::AllocaOp::handlePromotionComplete(
const MemorySlot &slot,
108 void memref::AllocaOp::handleBlockArgument(
const MemorySlot &slot,
113 memref::AllocaOp::getDestructurableSlots() {
114 MemRefType memrefType =
getType();
115 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
119 std::optional<DenseMap<Attribute, Type>> destructuredType =
120 destructurable.getSubelementIndexMap();
121 if (!destructuredType)
136 auto memrefType = llvm::cast<DestructurableTypeInterface>(
getType());
137 for (
Attribute usedIndex : usedIndices) {
138 Type elemType = memrefType.getTypeAtIndex(usedIndex);
140 auto subAlloca = builder.
create<memref::AllocaOp>(getLoc(), elemPtr);
141 newAllocators.push_back(subAlloca);
143 {subAlloca.getResult(), elemType});
149 std::optional<DestructurableAllocationOpInterface>
150 memref::AllocaOp::handleDestructuringComplete(
152 assert(slot.
ptr == getResult());
161 bool memref::LoadOp::loadsFrom(
const MemorySlot &slot) {
162 return getMemRef() == slot.
ptr;
165 bool memref::LoadOp::storesTo(
const MemorySlot &slot) {
return false; }
170 llvm_unreachable(
"getStored should not be called on LoadOp");
173 bool memref::LoadOp::canUsesBeRemoved(
177 if (blockingUses.size() != 1)
179 Value blockingUse = (*blockingUses.begin())->
get();
180 return blockingUse == slot.
ptr && getMemRef() == slot.
ptr &&
190 getResult().replaceAllUsesWith(reachingDefinition);
200 MemRefType memrefType) {
202 for (
auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
203 IntegerAttr coordAttr;
204 if (!
matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
207 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
208 if (!coordInt || coordInt.value() >=
static_cast<uint64_t
>(dimSize))
210 index.push_back(coordAttr);
219 if (slot.
ptr != getMemRef())
225 usedIndices.insert(index);
235 const MemorySlot &memorySlot = subslots.at(index);
236 setMemRef(memorySlot.
ptr);
237 getIndicesMutable().clear();
241 bool memref::StoreOp::loadsFrom(
const MemorySlot &slot) {
return false; }
243 bool memref::StoreOp::storesTo(
const MemorySlot &slot) {
244 return getMemRef() == slot.
ptr;
253 bool memref::StoreOp::canUsesBeRemoved(
257 if (blockingUses.size() != 1)
259 Value blockingUse = (*blockingUses.begin())->
get();
260 return blockingUse == slot.
ptr && getMemRef() == slot.
ptr &&
275 if (slot.
ptr != getMemRef() || getValue() == slot.
ptr)
281 usedIndices.insert(index);
291 const MemorySlot &memorySlot = subslots.at(index);
292 setMemRef(memorySlot.
ptr);
293 getIndicesMutable().clear();
303 struct MemRefDestructurableTypeExternalModel
304 :
public DestructurableTypeInterface::ExternalModel<
305 MemRefDestructurableTypeExternalModel, MemRefType> {
306 std::optional<DenseMap<Attribute, Type>>
307 getSubelementIndexMap(
Type type)
const {
308 auto memrefType = llvm::cast<MemRefType>(type);
309 constexpr int64_t maxMemrefSizeForDestructuring = 16;
310 if (!memrefType.hasStaticShape() ||
311 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
312 memrefType.getNumElements() == 1)
317 memrefType.getContext(), memrefType.getShape(), [&](
Attribute index) {
318 destructured.insert({index, memrefType.getElementType()});
325 auto memrefType = llvm::cast<MemRefType>(type);
326 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
327 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
331 for (
const auto &[coordAttr, dimSize] :
332 llvm::zip(coordArrAttr, memrefType.getShape())) {
333 auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
334 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
335 coord.getInt() >= dimSize)
339 return memrefType.getElementType();
351 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
static MLIRContext * getContext(OpFoldResult val)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribu...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
TypedAttr getZeroAttr(Type type)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
void erase()
Remove this operation from its parent block and delete it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry ®istry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > subelementTypes
Maps an index within the memory slot to the corresponding subelement type.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.