24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/ErrorHandling.h"
39 for (
size_t i = 0; i < shape.size(); ++i) {
41 if (index[i] < shape[i])
50 template <
typename CallableT>
57 for (int64_t dim : shapeIter)
68 return llvm::isa<MemRefType>(type) ||
73 MemRefType type = getType();
76 if (!type.hasStaticShape())
79 if (type.getNumElements() != 1)
82 return {
MemorySlot{getResult(), type.getElementType()}};
90 .Case([&](MemRefType t) {
91 return rewriter.
create<memref::AllocaOp>(getLoc(), t);
93 .Default([&](
Type t) {
94 return rewriter.
create<arith::ConstantOp>(getLoc(), t,
99 void memref::AllocaOp::handlePromotionComplete(
const MemorySlot &slot,
107 void memref::AllocaOp::handleBlockArgument(
const MemorySlot &slot,
112 memref::AllocaOp::getDestructurableSlots() {
113 MemRefType memrefType = getType();
114 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
118 std::optional<DenseMap<Attribute, Type>> destructuredType =
119 destructurable.getSubelementIndexMap();
120 if (!destructuredType)
124 for (
auto const &[index, type] : *destructuredType)
138 auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
139 for (
Attribute usedIndex : usedIndices) {
140 Type elemType = memrefType.getTypeAtIndex(usedIndex);
142 auto subAlloca = rewriter.
create<memref::AllocaOp>(getLoc(), elemPtr);
144 {subAlloca.getResult(), elemType});
150 void memref::AllocaOp::handleDestructuringComplete(
152 assert(slot.
ptr == getResult());
160 bool memref::LoadOp::loadsFrom(
const MemorySlot &slot) {
161 return getMemRef() == slot.
ptr;
164 bool memref::LoadOp::storesTo(
const MemorySlot &slot) {
return false; }
168 llvm_unreachable(
"getStored should not be called on LoadOp");
171 bool memref::LoadOp::canUsesBeRemoved(
174 if (blockingUses.size() != 1)
176 Value blockingUse = (*blockingUses.begin())->
get();
177 return blockingUse == slot.
ptr && getMemRef() == slot.
ptr &&
196 MemRefType memrefType) {
198 for (
auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
199 IntegerAttr coordAttr;
200 if (!
matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
203 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
204 if (!coordInt || coordInt.value() >=
static_cast<uint64_t
>(dimSize))
206 index.push_back(coordAttr);
214 if (slot.
ptr != getMemRef())
220 usedIndices.insert(index);
229 const MemorySlot &memorySlot = subslots.at(index);
231 setMemRef(memorySlot.
ptr);
232 getIndicesMutable().clear();
237 bool memref::StoreOp::loadsFrom(
const MemorySlot &slot) {
return false; }
239 bool memref::StoreOp::storesTo(
const MemorySlot &slot) {
240 return getMemRef() == slot.
ptr;
248 bool memref::StoreOp::canUsesBeRemoved(
251 if (blockingUses.size() != 1)
253 Value blockingUse = (*blockingUses.begin())->
get();
254 return blockingUse == slot.
ptr && getMemRef() == slot.
ptr &&
267 if (slot.
ptr != getMemRef() || getValue() == slot.
ptr)
273 usedIndices.insert(index);
282 const MemorySlot &memorySlot = subslots.at(index);
284 setMemRef(memorySlot.
ptr);
285 getIndicesMutable().clear();
296 struct MemRefDestructurableTypeExternalModel
297 :
public DestructurableTypeInterface::ExternalModel<
298 MemRefDestructurableTypeExternalModel, MemRefType> {
299 std::optional<DenseMap<Attribute, Type>>
300 getSubelementIndexMap(
Type type)
const {
301 auto memrefType = llvm::cast<MemRefType>(type);
302 constexpr int64_t maxMemrefSizeForDestructuring = 16;
303 if (!memrefType.hasStaticShape() ||
304 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
305 memrefType.getNumElements() == 1)
310 memrefType.getContext(), memrefType.getShape(), [&](
Attribute index) {
311 destructured.insert({index, memrefType.getElementType()});
318 auto memrefType = llvm::cast<MemRefType>(type);
319 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
320 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
324 for (
const auto &[coordAttr, dimSize] :
325 llvm::zip(coordArrAttr, memrefType.getShape())) {
326 auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
327 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
328 coord.getInt() >= dimSize)
332 return memrefType.getElementType();
344 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
static MLIRContext * getContext(OpFoldResult val)
static bool isSupportedElementType(Type type)
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, ValueRange indices, MemRefType memrefType)
Returns the index of a memref in attribute form, given its indices.
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef< int64_t > shape, CallableT &&walker)
Calls walker for each index within a tensor of a given shape, providing the index as an array attribu...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
TypedAttr getZeroAttr(Type type)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
void registerMemorySlotExternalModels(DialectRegistry ®istry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the type of the pointer that will be generated to access the ...
This class represents an efficient way to signal success or failure.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.