11 #include "llvm/ADT/TypeSwitch.h"
15 #define GEN_PASS_DEF_LLVMTYPECONSISTENCY
16 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
31 auto defOp = dyn_cast_or_null<GetResultPtrElementType>(addr.
getDefiningOp());
35 Type elemType = defOp.getResultPtrElementType();
39 if (elemType == expectedType)
55 for (
auto index : gep.getIndices()) {
56 IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
59 int32_t gepIndex = indexInt.getInt();
62 indices.push_back(
static_cast<uint32_t
>(gepIndex));
65 uint64_t offset = indices[0] * layout.
getTypeSize(gep.getElemType());
67 Type currentType = gep.getElemType();
68 for (uint32_t index : llvm::drop_begin(indices)) {
71 .Case([&](LLVMArrayType arrayType) {
72 if (arrayType.getNumElements() <= index)
74 offset += index * layout.
getTypeSize(arrayType.getElementType());
75 currentType = arrayType.getElementType();
80 if (body.size() <= index)
82 for (uint32_t i = 0; i < index; i++) {
84 offset = llvm::alignTo(offset,
88 currentType = body[index];
91 .Default([](
Type) {
return true; });
108 uint64_t rootIndex = offset / baseSize;
111 equivalentIndicesOut.push_back(rootIndex);
113 uint64_t distanceToStart = rootIndex * baseSize;
116 auto isWithinCurrentType = [&](
Type currentType) {
117 return offset < distanceToStart + layout.
getTypeSize(currentType);
121 Type currentType = base;
122 while (distanceToStart < offset) {
126 assert(isWithinCurrentType(currentType));
130 .Case([&](LLVMArrayType arrayType) {
134 uint64_t index = (offset - distanceToStart) / elemSize;
135 equivalentIndicesOut.push_back(index);
136 distanceToStart += index * elemSize;
141 currentType = arrayType.getElementType();
153 for (
Type elem : body) {
156 distanceToStart = llvm::alignTo(
157 distanceToStart, layout.getTypeABIAlignment(elem));
159 if (offset < distanceToStart)
163 if (offset < distanceToStart + elemSize) {
166 equivalentIndicesOut.push_back(index);
173 distanceToStart += elemSize;
201 if (!gep.getElemType())
204 std::optional<Type> maybeBaseType = gep.getElemType();
207 Type baseType = *maybeBaseType;
236 newIndices, gep.getInbounds());
244 class DestructurableTypeRange
245 :
public llvm::indexed_accessor_range<DestructurableTypeRange,
246 DestructurableTypeInterface, Type,
249 using Base = llvm::indexed_accessor_range<
250 DestructurableTypeRange, DestructurableTypeInterface,
Type,
Type *,
Type>;
257 explicit DestructurableTypeRange(DestructurableTypeInterface base)
258 : Base(base, 0, [&]() -> ptrdiff_t {
259 return
TypeSwitch<DestructurableTypeInterface, ptrdiff_t>(base)
261 return structType.getBody().size();
263 .Case([](LLVMArrayType arrayType) {
264 return arrayType.getNumElements();
266 .Default([](
auto) -> ptrdiff_t {
268 "Only LLVMStructType or LLVMArrayType supported");
273 bool isPacked()
const {
274 if (
auto structType = dyn_cast<LLVMStructType>(
getBase()))
275 return structType.isPacked();
280 static Type dereference(DestructurableTypeInterface base, ptrdiff_t index) {
283 Type result = base.getTypeAtIndex(
285 assert(result &&
"Should always succeed");
300 DestructurableTypeInterface destructurableType,
301 unsigned storeSize,
unsigned storeOffset) {
302 DestructurableTypeRange destructurableTypeRange(destructurableType);
304 unsigned currentOffset = 0;
305 for (; !destructurableTypeRange.empty();
306 destructurableTypeRange = destructurableTypeRange.drop_front()) {
307 Type type = destructurableTypeRange.front();
308 if (!destructurableTypeRange.isPacked()) {
310 currentOffset = llvm::alignTo(currentOffset, alignment);
316 if (currentOffset == storeOffset)
319 assert(currentOffset < storeOffset &&
320 "storeOffset should cleanly point into an immediate field");
325 size_t exclusiveEnd = 0;
326 for (; exclusiveEnd < destructurableTypeRange.size() && storeSize > 0;
328 if (!destructurableTypeRange.isPacked()) {
332 if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
337 dataLayout.
getTypeSize(destructurableTypeRange[exclusiveEnd]);
338 if (fieldSize > storeSize) {
342 auto subAggregate = dyn_cast<DestructurableTypeInterface>(
343 destructurableTypeRange[exclusiveEnd]);
353 return destructurableTypeRange.take_front(exclusiveEnd + 1);
355 currentOffset += fieldSize;
356 storeSize -= fieldSize;
363 return destructurableTypeRange.take_front(exclusiveEnd);
372 unsigned storeOffset) {
373 VectorType vectorType = value.getType();
374 unsigned elementSize = dataLayout.
getTypeSize(vectorType.getElementType());
377 for (
size_t index : llvm::seq<size_t>(0, vectorType.getNumElements())) {
380 auto extractOp = rewriter.
create<ExtractElementOp>(loc, value, pos);
384 auto gepOp = rewriter.
create<GEPOp>(
387 static_cast<int32_t
>(storeOffset + index * elementSize)});
389 rewriter.
create<StoreOp>(loc, extractOp, gepOp);
398 Value value,
unsigned storeSize,
399 unsigned storeOffset,
400 DestructurableTypeRange writtenToFields) {
401 unsigned currentOffset = storeOffset;
402 for (
Type type : writtenToFields) {
407 auto pos = rewriter.
create<ConstantOp>(
409 (currentOffset - storeOffset) * 8));
411 auto shrOp = rewriter.
create<LShrOp>(loc, value, pos);
417 IntegerType fieldIntType =
419 Value valueToStore = rewriter.
create<TruncOp>(loc, fieldIntType, shrOp);
423 auto gepOp = rewriter.
create<GEPOp>(
426 rewriter.
create<StoreOp>(loc, valueToStore, gepOp);
430 currentOffset += fieldSize;
431 storeSize -= fieldSize;
437 Type sourceType = store.getValue().getType();
438 if (!isa<IntegerType, VectorType>(sourceType)) {
451 unsigned storeSize = dataLayout.getTypeSize(sourceType);
453 Value address = store.getAddr();
459 if (gepOp.getIndices().size() != 2 ||
467 if (storeSize > dataLayout.getTypeSize(gepOp.getResultPtrElementType())) {
468 std::optional<uint64_t> byteOffset =
gepToByteOffset(dataLayout, gepOp);
472 offset = *byteOffset;
473 typeHint = gepOp.getElemType();
474 address = gepOp.getBase();
478 auto destructurableType = dyn_cast<DestructurableTypeInterface>(typeHint);
479 if (!destructurableType)
484 if (
failed(writtenToElements))
487 if (writtenToElements->size() <= 1) {
493 if (isa<IntegerType>(sourceType)) {
495 store.getValue(), storeSize, offset, *writtenToElements);
502 if (dataLayout.getTypeSizeInBits(sourceType) > maxVectorSplitSize)
517 if (
succeeded(typeHint) || gepOp.getIndices().size() <= 2) {
524 if (
auto integerAttr = dyn_cast<IntegerAttr>(index))
525 return integerAttr.getValue().getSExtValue();
526 return cast<Value>(index);
531 auto splitIter = std::next(indices.
begin(), 2);
534 auto subGepOp = rewriter.
create<GEPOp>(
535 gepOp.getLoc(), gepOp.getType(), gepOp.getElemType(), gepOp.getBase(),
536 llvm::map_to_vector(llvm::make_range(indices.
begin(), splitIter),
538 gepOp.getInbounds());
545 llvm::transform(llvm::make_range(splitIter, indices.
end()),
546 std::back_inserter(newIndices), indexToGEPArg);
548 subGepOp.getResultPtrElementType(),
549 subGepOp, newIndices, gepOp.getInbounds());
558 struct LLVMTypeConsistencyPass
559 :
public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
560 void runOnOperation()
override {
574 return std::make_unique<LLVMTypeConsistencyPass>();
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset, SmallVectorImpl< GEPArg > &equivalentIndicesOut)
Fills in equivalentIndicesOut with GEP indices that would be equivalent to offsetting a pointer by offset bytes, assuming the GEP has base as its base type.
static Type isElementTypeInconsistent(Value addr, Type expectedType)
Checks that a pointer value has a pointee type hint consistent with the expected type.
static void splitIntegerStore(const DataLayout &dataLayout, Location loc, RewriterBase &rewriter, Value address, Value value, unsigned storeSize, unsigned storeOffset, DestructurableTypeRange writtenToFields)
Splits a store of the integer value into address at storeOffset into multiple stores to each of the 'writtenToFields'.
static std::optional< uint64_t > gepToByteOffset(DataLayout &layout, GEPOp gep)
Returns the amount of bytes the provided GEP elements will offset the pointer by.
static FailureOr< Type > getRequiredConsistentGEPType(GEPOp gep)
Returns the consistent type for the GEP if the GEP is not type-consistent.
static FailureOr< DestructurableTypeRange > getWrittenToFields(const DataLayout &dataLayout, DestructurableTypeInterface destructurableType, unsigned storeSize, unsigned storeOffset)
Returns the list of elements of destructurableType that are written to by a store operation writing storeSize bytes at storeOffset.
static void splitVectorStore(const DataLayout &dataLayout, Location loc, RewriterBase &rewriter, Value address, TypedValue< VectorType > value, unsigned storeOffset)
Splits a store of the vector value into address at storeOffset into multiple stores of each element w...
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
Canonicalizes GEPs of which the base type and the pointer's type hint do not match.
LogicalResult matchAndRewrite(GEPOp gep, PatternRewriter &rewriter) const override
Class used for building a 'llvm.getelementptr'.
Class used for convenient access and iteration over GEP indices.
iterator begin() const
Returns the begin iterator, iterating over all GEP indices.
std::conditional_t< std::is_base_of< Attribute, llvm::detail::ValueOfRange< DynamicRange > >::value, Attribute, PointerUnion< IntegerAttr, llvm::detail::ValueOfRange< DynamicRange > >> value_type
Return type of 'operator[]' and the iterators 'operator*'.
iterator end() const
Returns the end iterator, iterating over all GEP indices.
LLVM dialect structure type representing a collection of different-typed elements manipulated together.
bool isPacked() const
Checks if a struct is packed.
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Splits GEPs with more than two indices into multiple GEPs with exactly two indices.
Splits stores which write into multiple adjacent elements of an aggregate through a pointer.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::unique_ptr< Pass > createTypeConsistencyPass()
Creates a pass that adjusts operations operating on pointers so they interpret pointee types as consi...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.