22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
32 if (!getOperation()->getBlock()->isEntryBlock())
35 return {
MemorySlot{getResult(), getResultPtrElementType()}};
43 void LLVM::AllocaOp::handleBlockArgument(
const MemorySlot &slot,
46 for (
Operation *user : getOperation()->getUsers())
47 if (
auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
48 rewriter.
create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
49 declareOp.getVarInfo());
52 void LLVM::AllocaOp::handlePromotionComplete(
const MemorySlot &slot,
55 if (defaultValue && defaultValue.
use_empty())
64 Type elemType = getResultPtrElementType();
65 auto destructurable = dyn_cast<DestructurableTypeInterface>(elemType);
69 std::optional<DenseMap<Attribute, Type>> destructuredType =
70 destructurable.getSubelementIndexMap();
71 if (!destructuredType)
75 for (
Attribute index : llvm::make_first_range(destructuredType.value()))
85 assert(slot.
ptr == getResult());
87 getElemType() ? *getElemType() : getResult().getType().
getElementType();
91 auto destructurableType = cast<DestructurableTypeInterface>(elemType);
94 Type elemType = destructurableType.getTypeAtIndex(index);
95 assert(elemType &&
"used index must exist");
96 auto subAlloca = rewriter.
create<LLVM::AllocaOp>(
99 slotMap.try_emplace<
MemorySlot>(index, {subAlloca.getResult(), elemType});
105 void LLVM::AllocaOp::handleDestructuringComplete(
107 assert(slot.
ptr == getResult());
115 bool LLVM::LoadOp::loadsFrom(
const MemorySlot &slot) {
116 return getAddr() == slot.
ptr;
/// Reports whether this load writes to `slot`. The argument is not inspected:
/// this implementation unconditionally returns false.
119 bool LLVM::LoadOp::storesTo(
const MemorySlot &slot) {
return false; }
122 llvm_unreachable(
"getStored should not be called on LoadOp");
/// Reports whether this store reads from `slot`. The argument is not
/// inspected: this implementation unconditionally returns false.
125 bool LLVM::StoreOp::loadsFrom(
const MemorySlot &slot) {
return false; }
127 bool LLVM::StoreOp::storesTo(
const MemorySlot &slot) {
128 return getAddr() == slot.
ptr;
135 bool LLVM::LoadOp::canUsesBeRemoved(
138 if (blockingUses.size() != 1)
140 Value blockingUse = (*blockingUses.begin())->
get();
145 return blockingUse == slot.
ptr && getAddr() == slot.
ptr &&
158 bool LLVM::StoreOp::canUsesBeRemoved(
161 if (blockingUses.size() != 1)
163 Value blockingUse = (*blockingUses.begin())->
get();
167 return blockingUse == slot.
ptr && getAddr() == slot.
ptr &&
178 if (
auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
179 rewriter.
create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
180 declareOp.getVarInfo());
192 getValue().getType() == slot.
elemType);
204 newBlockingUses.push_back(&use);
208 bool LLVM::BitcastOp::canUsesBeRemoved(
219 bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
230 bool LLVM::LifetimeStartOp::canUsesBeRemoved(
241 bool LLVM::LifetimeEndOp::canUsesBeRemoved(
252 bool LLVM::DbgDeclareOp::canUsesBeRemoved(
263 bool LLVM::DbgValueOp::canUsesBeRemoved(
267 if (blockingUses.size() != 1)
270 return (*blockingUses.begin())->get() == getValue();
282 rewriter.
create<UndefOp>(getValue().getLoc(), getValue().getType());
292 return llvm::all_of(gepOp.getIndices(), [](
auto index) {
293 auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
294 return indexAttr && indexAttr.getValue() == 0;
298 bool LLVM::GEPOp::canUsesBeRemoved(
314 llvm::dyn_cast_if_present<IntegerAttr>(gep.getIndices()[0]);
315 return index && index.getInt() == 0;
326 Type reachedType = getResultPtrElementType();
329 mustBeSafelyUsed.emplace_back<
MemorySlot>({getResult(), reachedType});
336 auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(
getBase().getType());
342 if (!basePtrType.isOpaque())
349 Type reachedType = getResultPtrElementType();
352 auto firstLevelIndex = dyn_cast<IntegerAttr>(
getIndices()[1]);
353 if (!firstLevelIndex)
355 assert(slot.
elementPtrs.contains(firstLevelIndex));
356 if (!llvm::isa<LLVM::LLVMPointerType>(slot.
elementPtrs.at(firstLevelIndex)))
358 mustBeSafelyUsed.emplace_back<
MemorySlot>({getResult(), reachedType});
359 usedIndices.insert(firstLevelIndex);
366 IntegerAttr firstLevelIndex =
367 llvm::dyn_cast_if_present<IntegerAttr>(
getIndices()[1]);
368 const MemorySlot &newSlot = subslots.at(firstLevelIndex);
376 if (remainingIndices.empty()) {
385 newIndices.append(remainingIndices.begin(), remainingIndices.end());
386 setRawConstantIndices(newIndices);
392 getBaseMutable().assign(newSlot.
ptr);
406 template <
class MemIntr>
407 std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
411 if (memIntrLen.getBitWidth() > 64)
413 return memIntrLen.getZExtValue();
421 std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
422 APInt memIntrLen = op.getLen();
423 if (memIntrLen.getBitWidth() > 64)
425 return memIntrLen.getZExtValue();
432 template <
class MemIntr>
435 if (!isa<LLVM::LLVMPointerType>(slot.
ptr.
getType()) ||
436 op.getDst() != slot.
ptr)
439 std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op);
447 return llvm::all_of(llvm::make_first_range(slot.
elementPtrs),
449 auto intIndex = dyn_cast<IntegerAttr>(index);
450 return intIndex && intIndex.getType() == i32;
/// Reports whether this memset reads from `slot`. The argument is not
/// inspected: this implementation unconditionally returns false.
458 bool LLVM::MemsetOp::loadsFrom(
const MemorySlot &slot) {
return false; }
460 bool LLVM::MemsetOp::storesTo(
const MemorySlot &slot) {
461 return getDst() == slot.
ptr;
468 .Case([&](IntegerType intType) ->
Value {
469 if (intType.getWidth() == 8)
472 assert(intType.getWidth() % 8 == 0);
476 uint64_t coveredBits = 8;
478 rewriter.
create<LLVM::ZExtOp>(getLoc(), intType, getVal());
479 while (coveredBits < intType.getWidth()) {
481 rewriter.
create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
483 rewriter.
create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
485 rewriter.
create<LLVM::OrOp>(getLoc(), currentValue, shifted);
493 "getStored should not be called on memset to unsupported type");
497 bool LLVM::MemsetOp::canUsesBeRemoved(
501 bool canConvertType =
503 .Case([](IntegerType intType) {
504 return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
539 .getSubelementIndexMap())
552 std::optional<DenseMap<Attribute, Type>> types =
553 slot.
elemType.
cast<DestructurableTypeInterface>().getSubelementIndexMap();
555 IntegerAttr memsetLenAttr;
556 bool successfulMatch =
557 matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
558 (void)successfulMatch;
559 assert(successfulMatch);
562 if (
auto structType = dyn_cast<LLVM::LLVMStructType>(slot.
elemType))
563 packed = structType.isPacked();
567 uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
568 uint64_t covered = 0;
569 for (
size_t i = 0; i < types->size(); i++) {
572 Type elemType = types->at(index);
573 uint64_t typeSize = dataLayout.
getTypeSize(elemType);
579 if (covered >= memsetLen)
584 if (subslots.contains(index)) {
585 uint64_t newMemsetSize =
std::min(memsetLen - covered, typeSize);
587 Value newMemsetSizeValue =
589 .
create<LLVM::ConstantOp>(
594 rewriter.
create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
595 getVal(), newMemsetSizeValue,
609 template <
class MemcpyLike>
611 return op.getSrc() == slot.
ptr;
614 template <
class MemcpyLike>
616 return op.getDst() == slot.
ptr;
619 template <
class MemcpyLike>
625 template <
class MemcpyLike>
633 if (op.getDst() == op.getSrc())
636 if (op.getIsVolatile())
643 template <
class MemcpyLike>
648 if (op.loadsFrom(slot))
649 rewriter.
create<LLVM::StoreOp>(op.
getLoc(), reachingDefinition,
654 template <
class MemcpyLike>
665 template <
class MemcpyLike>
669 if (op.getIsVolatile())
673 .getSubelementIndexMap())
684 if (op.getSrc() == slot.
ptr)
686 usedIndices.insert(index);
693 template <
class MemcpyLike>
696 Type toCpy,
bool isVolatile) {
697 Value memcpySize = rewriter.
create<LLVM::ConstantOp>(
700 rewriter.
create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
706 LLVM::MemcpyInlineOp toReplace,
Value dst,
707 Value src,
Type toCpy,
bool isVolatile) {
709 toReplace.getLen().getBitWidth());
710 rewriter.
create<LLVM::MemcpyInlineOp>(
711 toReplace.getLoc(), dst, src,
719 template <
class MemcpyLike>
724 if (subslots.empty())
729 assert((slot.
ptr == op.getDst()) != (slot.
ptr == op.getSrc()));
730 bool isDst = slot.
ptr == op.getDst();
733 size_t slotsTreated = 0;
738 Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
739 for (
size_t i = 0, e = slot.
elementPtrs.size(); i != e; i++) {
741 if (!subslots.contains(index))
743 const MemorySlot &subslot = subslots.at(index);
752 0,
static_cast<int32_t
>(
753 cast<IntegerAttr>(index).getValue().getZExtValue())};
754 Value subslotPtrInOther = rewriter.
create<LLVM::GEPOp>(
756 isDst ? op.getSrc() : op.getDst(), gepIndices);
759 createMemcpyLikeToReplace(rewriter, layout, op,
760 isDst ? subslot.
ptr : subslotPtrInOther,
761 isDst ? subslotPtrInOther : subslot.
ptr,
762 subslot.
elemType, op.getIsVolatile());
765 assert(subslots.size() == slotsTreated);
770 bool LLVM::MemcpyOp::loadsFrom(
const MemorySlot &slot) {
774 bool LLVM::MemcpyOp::storesTo(
const MemorySlot &slot) {
783 bool LLVM::MemcpyOp::canUsesBeRemoved(
813 bool LLVM::MemcpyInlineOp::loadsFrom(
const MemorySlot &slot) {
817 bool LLVM::MemcpyInlineOp::storesTo(
const MemorySlot &slot) {
826 bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
844 bool LLVM::MemcpyInlineOp::canRewire(
858 bool LLVM::MemmoveOp::loadsFrom(
const MemorySlot &slot) {
862 bool LLVM::MemmoveOp::storesTo(
const MemorySlot &slot) {
871 bool LLVM::MemmoveOp::canUsesBeRemoved(
905 std::optional<DenseMap<Attribute, Type>>
915 auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
916 if (!indexAttr || !indexAttr.getType().isInteger(32))
918 int32_t indexInt = indexAttr.getInt();
920 if (indexInt < 0 || body.size() <=
static_cast<uint32_t
>(indexInt))
922 return body[indexInt];
925 std::optional<DenseMap<Attribute, Type>>
926 LLVM::LLVMArrayType::getSubelementIndexMap()
const {
927 constexpr
size_t maxArraySizeForDestructuring = 16;
934 for (int32_t index = 0; index < numElements; ++index)
935 destructured.insert({IntegerAttr::get(i32, index), getElementType()});
939 Type LLVM::LLVMArrayType::getTypeAtIndex(
Attribute index)
const {
940 auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
941 if (!indexAttr || !indexAttr.getType().isInteger(32))
943 int32_t indexInt = indexAttr.getInt();
944 if (indexInt < 0 ||
getNumElements() <=
static_cast<uint32_t
>(indexInt))
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed)
static DeletionKind memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, RewriterBase &rewriter, Value reachingDefinition)
static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot, DataLayout &dataLayout)
Returns whether one can be sure the memory intrinsic does not write outside of the bounds of the given slot, on a best-effort basis.
static bool areAllIndicesI32(const DestructurableMemorySlot &slot)
Checks whether all indices are i32.
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot)
static bool isFirstIndexZero(LLVM::GEPOp gep)
static bool forwardToUsers(Operation *op, SmallVectorImpl< OpOperand * > &newBlockingUses)
Makes the deletion of the operation conditional on the removal of all its uses.
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot)
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot, SmallPtrSetImpl< Attribute > &usedIndices, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed)
static bool hasAllZeroIndices(LLVM::GEPOp gepOp)
static bool memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, SmallVectorImpl< OpOperand * > &newBlockingUses)
static DeletionKind memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot, DenseMap< Attribute, MemorySlot > &subslots, RewriterBase &rewriter)
Rewires a memcpy-like operation.
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot, RewriterBase &rewriter)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity, to select an element type.
static int64_t getNumElements(ShapedType type)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
unsigned getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
unsigned getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Type getTypeAtIndex(Attribute index)
Returns which type is stored at a given integer index within the struct.
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
std::optional< DenseMap< Attribute, Type > > getSubelementIndexMap()
Destructs the struct into its indexed field types.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the type of the pointer that will be generated to access the element directly.
This class represents an efficient way to signal success or failure.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.