24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/ErrorHandling.h"
39 for (
size_t i = 0; i < shape.size(); ++i) {
41 if (index[i] < shape[i])
50 template <
typename CallableT>
57 for (int64_t dim : shapeIter)
68 return llvm::isa<MemRefType>(type) ||
73 MemRefType type = getType();
76 if (!type.hasStaticShape())
79 if (type.getNumElements() != 1)
82 return {
MemorySlot{getResult(), type.getElementType()}};
90 .Case([&](MemRefType t) {
91 return rewriter.
create<memref::AllocaOp>(getLoc(), t);
93 .Default([&](
Type t) {
94 return rewriter.
create<arith::ConstantOp>(getLoc(), t,
99 void memref::AllocaOp::handlePromotionComplete(
const MemorySlot &slot,
107 void memref::AllocaOp::handleBlockArgument(
const MemorySlot &slot,
112 memref::AllocaOp::getDestructurableSlots() {
113 MemRefType memrefType = getType();
114 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
118 std::optional<DenseMap<Attribute, Type>> destructuredType =
119 destructurable.getSubelementIndexMap();
120 if (!destructuredType)
135 auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
136 for (
Attribute usedIndex : usedIndices) {
137 Type elemType = memrefType.getTypeAtIndex(usedIndex);
139 auto subAlloca = rewriter.
create<memref::AllocaOp>(getLoc(), elemPtr);
141 {subAlloca.getResult(), elemType});
147 void memref::AllocaOp::handleDestructuringComplete(
149 assert(slot.
ptr == getResult());
157 bool memref::LoadOp::loadsFrom(
const MemorySlot &slot) {
158 return getMemRef() == slot.
ptr;
// Unconditionally returns false: the `slot` argument is not inspected, so this
// op is never reported as storing to any memory slot.
161 bool memref::LoadOp::storesTo(
const MemorySlot &slot) {
return false; }
166 llvm_unreachable(
"getStored should not be called on LoadOp");
169 bool memref::LoadOp::canUsesBeRemoved(
173 if (blockingUses.size() != 1)
175 Value blockingUse = (*blockingUses.begin())->
get();
176 return blockingUse == slot.
ptr && getMemRef() == slot.
ptr &&
196 MemRefType memrefType) {
198 for (
auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
199 IntegerAttr coordAttr;
200 if (!
matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
203 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
204 if (!coordInt || coordInt.value() >=
static_cast<uint64_t
>(dimSize))
206 index.push_back(coordAttr);
215 if (slot.
ptr != getMemRef())
221 usedIndices.insert(index);
231 const MemorySlot &memorySlot = subslots.at(index);
233 setMemRef(memorySlot.
ptr);
234 getIndicesMutable().clear();
// Unconditionally returns false: the `slot` argument is not inspected, so this
// op is never reported as loading from any memory slot.
239 bool memref::StoreOp::loadsFrom(
const MemorySlot &slot) {
return false; }
241 bool memref::StoreOp::storesTo(
const MemorySlot &slot) {
242 return getMemRef() == slot.
ptr;
251 bool memref::StoreOp::canUsesBeRemoved(
255 if (blockingUses.size() != 1)
257 Value blockingUse = (*blockingUses.begin())->
get();
258 return blockingUse == slot.
ptr && getMemRef() == slot.
ptr &&
273 if (slot.
ptr != getMemRef() || getValue() == slot.
ptr)
279 usedIndices.insert(index);
289 const MemorySlot &memorySlot = subslots.at(index);
291 setMemRef(memorySlot.
ptr);
292 getIndicesMutable().clear();
303 struct MemRefDestructurableTypeExternalModel
304 :
public DestructurableTypeInterface::ExternalModel<
305 MemRefDestructurableTypeExternalModel, MemRefType> {
306 std::optional<DenseMap<Attribute, Type>>
307 getSubelementIndexMap(
Type type)
const {
308 auto memrefType = llvm::cast<MemRefType>(type);
309 constexpr int64_t maxMemrefSizeForDestructuring = 16;
310 if (!memrefType.hasStaticShape() ||
311 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
312 memrefType.getNumElements() == 1)
317 memrefType.getContext(), memrefType.getShape(), [&](
Attribute index) {
318 destructured.insert({index, memrefType.getElementType()});
325 auto memrefType = llvm::cast<MemRefType>(type);
326 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
327 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
331 for (
const auto &[coordAttr, dimSize] :
332 llvm::zip(coordArrAttr, memrefType.getShape())) {
333 auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
334 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
335 coord.getInt() >= dimSize)
339 return memrefType.getElementType();
351 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
static MLIRContext * getContext(OpFoldResult val)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribu...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
TypedAttr getZeroAttr(Type type)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the corresponding subelement type.
This class represents an efficient way to signal success or failure.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.