11 #include "llvm/ADT/TypeSwitch.h"
15 #define GEN_PASS_DEF_LLVMTYPECONSISTENCY
16 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
31 auto defOp = dyn_cast_or_null<GetResultPtrElementType>(addr.
getDefiningOp());
35 Type elemType = defOp.getResultPtrElementType();
39 if (elemType == expectedType)
47 return lhs == rhs || (!isa<LLVMStructType, LLVMArrayType>(lhs) &&
48 !isa<LLVMStructType, LLVMArrayType>(rhs) &&
59 auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
63 Type subelementType = destructurable.getTypeAtIndex(
66 return subelementType;
74 template <
class MemOp>
77 PatternRewriter::InsertionGuard guard(rewriter);
84 op.getAddr(), firstTypeIndices);
87 [&]() { op.getAddrMutable().assign(properPtr); });
93 PatternRewriter::InsertionGuard guard(rewriter);
95 Type inconsistentElementType =
97 if (!inconsistentElementType)
106 insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
110 if (firstType != load.getResult().getType()) {
112 BitcastOp bitcast = rewriter.
create<BitcastOp>(
113 load->getLoc(), load.getResult().getType(), load.getResult());
115 [&]() { load.getResult().setType(firstType); });
126 PatternRewriter::InsertionGuard guard(rewriter);
128 Type inconsistentElementType =
130 if (!inconsistentElementType)
142 insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
145 store, [&]() { store.getValueMutable().assign(store.getValue()); });
160 for (
auto index : gep.getIndices()) {
161 IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
164 indices.push_back(indexInt.getInt());
167 uint64_t offset = indices[0] * layout.
getTypeSize(gep.getElemType());
169 Type currentType = gep.getElemType();
170 for (uint32_t index : llvm::drop_begin(indices)) {
173 .Case([&](LLVMArrayType arrayType) {
174 if (arrayType.getNumElements() <= index)
176 offset += index * layout.
getTypeSize(arrayType.getElementType());
177 currentType = arrayType.getElementType();
182 if (body.size() <= index)
184 for (uint32_t i = 0; i < index; i++) {
186 offset = llvm::alignTo(offset,
190 currentType = body[index];
193 .Default([](
Type) {
return true; });
210 uint64_t rootIndex = offset / baseSize;
213 equivalentIndicesOut.push_back(rootIndex);
215 uint64_t distanceToStart = rootIndex * baseSize;
218 auto isWithinCurrentType = [&](
Type currentType) {
219 return offset < distanceToStart + layout.
getTypeSize(currentType);
223 Type currentType = base;
224 while (distanceToStart < offset) {
228 assert(isWithinCurrentType(currentType));
232 .Case([&](LLVMArrayType arrayType) {
236 uint64_t index = (offset - distanceToStart) / elemSize;
237 equivalentIndicesOut.push_back(index);
238 distanceToStart += index * elemSize;
243 currentType = arrayType.getElementType();
255 for (
Type elem : body) {
258 distanceToStart = llvm::alignTo(
259 distanceToStart, layout.getTypeABIAlignment(elem));
261 if (offset < distanceToStart)
265 if (offset < distanceToStart + elemSize) {
268 equivalentIndicesOut.push_back(index);
275 distanceToStart += elemSize;
303 if (!gep.getElemType())
306 std::optional<Type> maybeBaseType = gep.getElemType();
309 Type baseType = *maybeBaseType;
338 newIndices, gep.getInbounds());
346 class DestructurableTypeRange
347 :
public llvm::indexed_accessor_range<DestructurableTypeRange,
348 DestructurableTypeInterface, Type,
351 using Base = llvm::indexed_accessor_range<
352 DestructurableTypeRange, DestructurableTypeInterface,
Type,
Type *,
Type>;
359 explicit DestructurableTypeRange(DestructurableTypeInterface base)
360 : Base(base, 0, [&]() -> ptrdiff_t {
361 return
TypeSwitch<DestructurableTypeInterface, ptrdiff_t>(base)
363 return structType.getBody().size();
365 .Case([](LLVMArrayType arrayType) {
366 return arrayType.getNumElements();
368 .Default([](
auto) -> ptrdiff_t {
370 "Only LLVMStructType or LLVMArrayType supported");
375 bool isPacked()
const {
376 if (
auto structType = dyn_cast<LLVMStructType>(
getBase()))
377 return structType.isPacked();
382 static Type dereference(DestructurableTypeInterface base, ptrdiff_t index) {
385 Type result = base.getTypeAtIndex(
387 assert(result &&
"Should always succeed");
402 DestructurableTypeInterface destructurableType,
403 unsigned storeSize,
unsigned storeOffset) {
404 DestructurableTypeRange destructurableTypeRange(destructurableType);
406 unsigned currentOffset = 0;
407 for (; !destructurableTypeRange.empty();
408 destructurableTypeRange = destructurableTypeRange.drop_front()) {
409 Type type = destructurableTypeRange.front();
410 if (!destructurableTypeRange.isPacked()) {
412 currentOffset = llvm::alignTo(currentOffset, alignment);
418 if (currentOffset == storeOffset)
421 assert(currentOffset < storeOffset &&
422 "storeOffset should cleanly point into an immediate field");
427 size_t exclusiveEnd = 0;
428 for (; exclusiveEnd < destructurableTypeRange.size() && storeSize > 0;
430 if (!destructurableTypeRange.isPacked()) {
434 if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
439 dataLayout.
getTypeSize(destructurableTypeRange[exclusiveEnd]);
440 if (fieldSize > storeSize) {
444 auto subAggregate = dyn_cast<DestructurableTypeInterface>(
445 destructurableTypeRange[exclusiveEnd]);
455 return destructurableTypeRange.take_front(exclusiveEnd + 1);
457 currentOffset += fieldSize;
458 storeSize -= fieldSize;
465 return destructurableTypeRange.take_front(exclusiveEnd);
474 unsigned storeOffset) {
475 VectorType vectorType = value.getType();
476 unsigned elementSize = dataLayout.
getTypeSize(vectorType.getElementType());
479 for (
size_t index : llvm::seq<size_t>(0, vectorType.getNumElements())) {
482 auto extractOp = rewriter.
create<ExtractElementOp>(loc, value, pos);
486 auto gepOp = rewriter.
create<GEPOp>(
490 rewriter.
create<StoreOp>(loc, extractOp, gepOp);
499 Value value,
unsigned storeSize,
500 unsigned storeOffset,
501 DestructurableTypeRange writtenToFields) {
502 unsigned currentOffset = storeOffset;
503 for (
Type type : writtenToFields) {
508 auto pos = rewriter.
create<ConstantOp>(
510 (currentOffset - storeOffset) * 8));
512 auto shrOp = rewriter.
create<LShrOp>(loc, value, pos);
518 IntegerType fieldIntType =
520 Value valueToStore = rewriter.
create<TruncOp>(loc, fieldIntType, shrOp);
527 rewriter.
create<StoreOp>(loc, valueToStore, gepOp);
531 currentOffset += fieldSize;
532 storeSize -= fieldSize;
538 Type sourceType = store.getValue().getType();
539 if (!isa<IntegerType, VectorType>(sourceType)) {
552 unsigned storeSize = dataLayout.getTypeSize(sourceType);
554 Value address = store.getAddr();
560 if (gepOp.getIndices().size() != 2 ||
568 if (storeSize > dataLayout.getTypeSize(gepOp.getResultPtrElementType())) {
569 std::optional<uint64_t> byteOffset =
gepToByteOffset(dataLayout, gepOp);
573 offset = *byteOffset;
574 typeHint = gepOp.getElemType();
575 address = gepOp.getBase();
579 auto destructurableType = typeHint.
dyn_cast<DestructurableTypeInterface>();
580 if (!destructurableType)
585 if (
failed(writtenToElements))
588 if (writtenToElements->size() <= 1) {
594 if (isa<IntegerType>(sourceType)) {
596 store.getValue(), storeSize, offset, *writtenToElements);
603 if (dataLayout.getTypeSizeInBits(sourceType) > maxVectorSplitSize)
617 Type sourceType = store.getValue().getType();
629 rewriter.
create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
631 store, [&] { store.getValueMutable().assign(bitcastOp); });
638 if (
succeeded(typeHint) || gepOp.getIndices().size() <= 2) {
645 if (
auto integerAttr = dyn_cast<IntegerAttr>(index))
646 return integerAttr.getValue().getSExtValue();
647 return cast<Value>(index);
652 auto splitIter = std::next(indices.
begin(), 2);
655 auto subGepOp = rewriter.
create<GEPOp>(
656 gepOp.getLoc(), gepOp.getType(), gepOp.getElemType(), gepOp.getBase(),
657 llvm::map_to_vector(llvm::make_range(indices.
begin(), splitIter),
659 gepOp.getInbounds());
666 llvm::transform(llvm::make_range(splitIter, indices.
end()),
667 std::back_inserter(newIndices), indexToGEPArg);
669 subGepOp.getResultPtrElementType(),
670 subGepOp, newIndices, gepOp.getInbounds());
679 struct LLVMTypeConsistencyPass
680 :
public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
681 void runOnOperation()
override {
699 return std::make_unique<LLVMTypeConsistencyPass>();
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset, SmallVectorImpl< GEPArg > &equivalentIndicesOut)
Fills in equivalentIndicesOut with GEP indices that would be equivalent to offsetting a pointer by of...
static Type isElementTypeInconsistent(Value addr, Type expectedType)
Checks that a pointer value has a pointee type hint consistent with the expected type.
static void splitIntegerStore(const DataLayout &dataLayout, Location loc, RewriterBase &rewriter, Value address, Value value, unsigned storeSize, unsigned storeOffset, DestructurableTypeRange writtenToFields)
Splits a store of the integer value into address at storeOffset into multiple stores to each 'written...
static std::optional< uint64_t > gepToByteOffset(DataLayout &layout, GEPOp gep)
Returns the amount of bytes the provided GEP elements will offset the pointer by.
static FailureOr< Type > getRequiredConsistentGEPType(GEPOp gep)
Returns the consistent type for the GEP if the GEP is not type-consistent.
static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter, Type elemType)
Extracts a pointer to the first field of an elemType from the address pointer of the provided MemOp,...
static Type getFirstSubelementType(Type type)
Gets the type of the first subelement of type if type is destructurable, nullptr otherwise.
static FailureOr< DestructurableTypeRange > getWrittenToFields(const DataLayout &dataLayout, DestructurableTypeInterface destructurableType, unsigned storeSize, unsigned storeOffset)
Returns the list of elements of destructurableType that are written to by a store operation writing s...
static void splitVectorStore(const DataLayout &dataLayout, Location loc, RewriterBase &rewriter, Value address, TypedValue< VectorType > value, unsigned storeOffset)
Splits a store of the vector value into address at storeOffset into multiple stores of each element w...
static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs)
Checks that two types are the same or can be bitcast into one another.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
Transforms uses of pointers to a whole struct to uses of pointers to the first element of a struct.
LogicalResult matchAndRewrite(User user, PatternRewriter &rewriter) const override
Transforms type-inconsistent stores, aka stores where the type hint of the address contradicts the va...
Canonicalizes GEPs of which the base type and the pointer's type hint do not match.
LogicalResult matchAndRewrite(GEPOp gep, PatternRewriter &rewriter) const override
Class used for building a 'llvm.getelementptr'.
Class used for convenient access and iteration over GEP indices.
iterator begin() const
Returns the begin iterator, iterating over all GEP indices.
std::conditional_t< std::is_base_of< Attribute, llvm::detail::ValueOfRange< DynamicRange > >::value, Attribute, PointerUnion< IntegerAttr, llvm::detail::ValueOfRange< DynamicRange > >> value_type
Return type of 'operator[]' and the iterators 'operator*'.
iterator end() const
Returns the end iterator, iterating over all GEP indices.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
bool isPacked() const
Checks if a struct is packed.
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Splits GEPs with more than two indices into multiple GEPs with exactly two indices.
Splits stores which write into multiple adjacent elements of an aggregate through a pointer.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::unique_ptr< Pass > createTypeConsistencyPass()
Creates a pass that adjusts operations operating on pointers so they interpret pointee types as consi...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.