MLIR  18.0.0git
TypeConsistency.cpp
Go to the documentation of this file.
1 //===- TypeConsistency.cpp - Rewrites to improve type consistency ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "llvm/ADT/TypeSwitch.h"
12 
13 namespace mlir {
14 namespace LLVM {
15 #define GEN_PASS_DEF_LLVMTYPECONSISTENCY
16 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
17 } // namespace LLVM
18 } // namespace mlir
19 
20 using namespace mlir;
21 using namespace LLVM;
22 
23 //===----------------------------------------------------------------------===//
24 // Utils
25 //===----------------------------------------------------------------------===//
26 
27 /// Checks that a pointer value has a pointee type hint consistent with the
28 /// expected type. Returns the type it actually hints to if it differs, or
29 /// nullptr if the type is consistent or impossible to analyze.
30 static Type isElementTypeInconsistent(Value addr, Type expectedType) {
31  auto defOp = dyn_cast_or_null<GetResultPtrElementType>(addr.getDefiningOp());
32  if (!defOp)
33  return nullptr;
34 
35  Type elemType = defOp.getResultPtrElementType();
36  if (!elemType)
37  return nullptr;
38 
39  if (elemType == expectedType)
40  return nullptr;
41 
42  return elemType;
43 }
44 
45 /// Checks that two types are the same or can be bitcast into one another.
46 static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
47  return lhs == rhs || (!isa<LLVMStructType, LLVMArrayType>(lhs) &&
48  !isa<LLVMStructType, LLVMArrayType>(rhs) &&
49  layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // AddFieldGetterToStructDirectUse
54 //===----------------------------------------------------------------------===//
55 
/// Gets the type of the first subelement of `type` if `type` is destructurable,
/// nullptr otherwise.
// NOTE(review): the extracted listing dropped the declaration line here; the
// index in this page names it `static Type getFirstSubelementType(Type type)`
// — confirm against the original source.
  auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
  if (!destructurable)
    return nullptr;

  // Query the subelement at index 0.
  // NOTE(review): the extracted listing also dropped the argument line of
  // this call — confirm against the original source.
  Type subelementType = destructurable.getTypeAtIndex(
  if (subelementType)
    return subelementType;

  return nullptr;
}
70 
71 /// Extracts a pointer to the first field of an `elemType` from the address
72 /// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer
73 /// instead.
74 template <class MemOp>
75 static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
76  Type elemType) {
77  PatternRewriter::InsertionGuard guard(rewriter);
78 
79  rewriter.setInsertionPointAfterValue(op.getAddr());
80  SmallVector<GEPArg> firstTypeIndices{0, 0};
81 
82  Value properPtr = rewriter.create<GEPOp>(
83  op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
84  op.getAddr(), firstTypeIndices);
85 
86  rewriter.updateRootInPlace(op,
87  [&]() { op.getAddrMutable().assign(properPtr); });
88 }
89 
template <>
// NOTE(review): the extracted listing dropped the signature line here; the
// index in this page names it as the `matchAndRewrite` definition of
// `AddFieldGetterToStructDirectUse` (LoadOp specialization) — confirm against
// the original source.
 LoadOp load, PatternRewriter &rewriter) const {
  PatternRewriter::InsertionGuard guard(rewriter);

  // Only fire when the address' pointee type hint disagrees with the type
  // the load produces.
  Type inconsistentElementType =
      isElementTypeInconsistent(load.getAddr(), load.getType());
  if (!inconsistentElementType)
    return failure();
  // The rewrite redirects the load to the hinted type's first subelement, so
  // that subelement must exist...
  Type firstType = getFirstSubelementType(inconsistentElementType);
  if (!firstType)
    return failure();
  // ...and must be bitcast-compatible with the loaded type.
  DataLayout layout = DataLayout::closest(load);
  if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
    return failure();

  // Rewire the load through a [0, 0] GEP into the hinted type.
  insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);

  // If the load does not use the first type but a type that can be casted from
  // it, add a bitcast and change the load type.
  if (firstType != load.getResult().getType()) {
    rewriter.setInsertionPointAfterValue(load.getResult());
    BitcastOp bitcast = rewriter.create<BitcastOp>(
        load->getLoc(), load.getResult().getType(), load.getResult());
    rewriter.updateRootInPlace(load,
                               [&]() { load.getResult().setType(firstType); });
    // Every user except the bitcast itself now consumes the casted value.
    rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
                                  bitcast);
  }

  return success();
}
122 
template <>
// NOTE(review): the extracted listing dropped the signature line here; the
// index in this page names it as the `matchAndRewrite` definition of
// `AddFieldGetterToStructDirectUse` (StoreOp specialization) — confirm
// against the original source.
 StoreOp store, PatternRewriter &rewriter) const {
  PatternRewriter::InsertionGuard guard(rewriter);

  // Only fire when the address' pointee type hint disagrees with the type of
  // the stored value.
  Type inconsistentElementType =
      isElementTypeInconsistent(store.getAddr(), store.getValue().getType());
  if (!inconsistentElementType)
    return failure();
  Type firstType = getFirstSubelementType(inconsistentElementType);
  if (!firstType)
    return failure();

  DataLayout layout = DataLayout::closest(store);
  // Check that the first field has the right type or can at least be bitcast
  // to the right type.
  if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
    return failure();

  // Rewire the store through a [0, 0] GEP into the hinted type.
  insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);

  // NOTE(review): assigning the value operand to itself leaves the IR
  // unchanged; unlike the load case above, no bitcast of the stored value is
  // inserted even though only bitcast compatibility was checked — verify
  // against the original source whether a bitcast was intended here.
  rewriter.updateRootInPlace(
      store, [&]() { store.getValueMutable().assign(store.getValue()); });

  return success();
}
149 
150 //===----------------------------------------------------------------------===//
151 // CanonicalizeAlignedGep
152 //===----------------------------------------------------------------------===//
153 
154 /// Returns the amount of bytes the provided GEP elements will offset the
155 /// pointer by. Returns nullopt if the offset could not be computed.
156 static std::optional<uint64_t> gepToByteOffset(DataLayout &layout, GEPOp gep) {
157 
158  SmallVector<uint32_t> indices;
159  // Ensures all indices are static and fetches them.
160  for (auto index : gep.getIndices()) {
161  IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
162  if (!indexInt)
163  return std::nullopt;
164  indices.push_back(indexInt.getInt());
165  }
166 
167  uint64_t offset = indices[0] * layout.getTypeSize(gep.getElemType());
168 
169  Type currentType = gep.getElemType();
170  for (uint32_t index : llvm::drop_begin(indices)) {
171  bool shouldCancel =
172  TypeSwitch<Type, bool>(currentType)
173  .Case([&](LLVMArrayType arrayType) {
174  if (arrayType.getNumElements() <= index)
175  return true;
176  offset += index * layout.getTypeSize(arrayType.getElementType());
177  currentType = arrayType.getElementType();
178  return false;
179  })
180  .Case([&](LLVMStructType structType) {
181  ArrayRef<Type> body = structType.getBody();
182  if (body.size() <= index)
183  return true;
184  for (uint32_t i = 0; i < index; i++) {
185  if (!structType.isPacked())
186  offset = llvm::alignTo(offset,
187  layout.getTypeABIAlignment(body[i]));
188  offset += layout.getTypeSize(body[i]);
189  }
190  currentType = body[index];
191  return false;
192  })
193  .Default([](Type) { return true; });
194 
195  if (shouldCancel)
196  return std::nullopt;
197  }
198 
199  return offset;
200 }
201 
/// Fills in `equivalentIndicesOut` with GEP indices that would be equivalent to
/// offsetting a pointer by `offset` bytes, assuming the GEP has `base` as base
/// type. Returns failure when no such index sequence exists, e.g. when the
/// offset lands in struct padding, past the aggregate, or inside a leaf type.
static LogicalResult
findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset,
                     SmallVectorImpl<GEPArg> &equivalentIndicesOut) {

  // The first GEP index steps over whole `base` elements.
  uint64_t baseSize = layout.getTypeSize(base);
  uint64_t rootIndex = offset / baseSize;
  // Indices are materialized as i32, so bail out on overflow.
  if (rootIndex > std::numeric_limits<uint32_t>::max())
    return failure();
  equivalentIndicesOut.push_back(rootIndex);

  // Byte distance from the original pointer to the start of the element
  // selected by the indices emitted so far.
  uint64_t distanceToStart = rootIndex * baseSize;

#ifndef NDEBUG
  auto isWithinCurrentType = [&](Type currentType) {
    return offset < distanceToStart + layout.getTypeSize(currentType);
  };
#endif

  // Drill down into `base` until the emitted indices land exactly on
  // `offset`.
  Type currentType = base;
  while (distanceToStart < offset) {
    // While an index that does not perfectly align with offset has not been
    // reached...

    assert(isWithinCurrentType(currentType));

    bool shouldCancel =
        TypeSwitch<Type, bool>(currentType)
            .Case([&](LLVMArrayType arrayType) {
              // Find which element of the array contains the offset.
              uint64_t elemSize =
                  layout.getTypeSize(arrayType.getElementType());
              uint64_t index = (offset - distanceToStart) / elemSize;
              equivalentIndicesOut.push_back(index);
              distanceToStart += index * elemSize;

              // Then, try to find where in the element the offset is. If the
              // offset is exactly the beginning of the element, the loop is
              // complete.
              currentType = arrayType.getElementType();

              // Only continue if the element in question can be indexed using
              // an i32.
              return index > std::numeric_limits<uint32_t>::max();
            })
            .Case([&](LLVMStructType structType) {
              ArrayRef<Type> body = structType.getBody();
              uint32_t index = 0;

              // Walk over the elements of the struct to find in which of them
              // the offset is.
              for (Type elem : body) {
                uint64_t elemSize = layout.getTypeSize(elem);
                if (!structType.isPacked()) {
                  distanceToStart = llvm::alignTo(
                      distanceToStart, layout.getTypeABIAlignment(elem));
                  // If the offset is in padding, cancel the rewrite.
                  if (offset < distanceToStart)
                    return true;
                }

                if (offset < distanceToStart + elemSize) {
                  // The offset is within this element, stop iterating the
                  // struct and look within the current element.
                  equivalentIndicesOut.push_back(index);
                  currentType = elem;
                  return false;
                }

                // The offset is not within this element, continue walking over
                // the struct.
                distanceToStart += elemSize;
                index++;
              }

              // The offset was supposed to be within this struct but is not.
              // This can happen if the offset points into final padding.
              // Anyway, nothing can be done.
              return true;
            })
            .Default([](Type) {
              // If the offset is within a type that cannot be split, no indices
              // will yield this offset. This can happen if the offset is not
              // perfectly aligned with a leaf type.
              // TODO: support vectors.
              return true;
            });

    if (shouldCancel)
      return failure();
  }

  return success();
}
298 
/// Returns the consistent type for the GEP if the GEP is not type-consistent.
/// Returns failure if the GEP is already consistent.
// NOTE(review): the extracted listing dropped the declaration line here; the
// index in this page names it `static FailureOr<Type>
// getRequiredConsistentGEPType(GEPOp gep)` — confirm against the original
// source.
  // GEP of typed pointers are not supported.
  if (!gep.getElemType())
    return failure();

  // NOTE(review): this optional check repeats the check directly above and
  // looks redundant; preserved as-is.
  std::optional<Type> maybeBaseType = gep.getElemType();
  if (!maybeBaseType)
    return failure();
  Type baseType = *maybeBaseType;

  // The GEP is inconsistent iff its base pointer carries a pointee type hint
  // that differs from the GEP's element type.
  Type typeHint = isElementTypeInconsistent(gep.getBase(), baseType);
  if (!typeHint)
    return failure();
  return typeHint;
}
316 
// NOTE(review): the extracted listing dropped lines here; per the index this
// is `CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep, ...)`, and
// `typeHint` is initialized from `getRequiredConsistentGEPType(gep)` —
// confirm against the original source.
 PatternRewriter &rewriter) const {
  if (failed(typeHint)) {
    // GEP is already canonical, nothing to do here.
    return failure();
  }

  // Compute the byte offset the existing (inconsistent) GEP applies.
  DataLayout layout = DataLayout::closest(gep);
  std::optional<uint64_t> desiredOffset = gepToByteOffset(layout, gep);
  if (!desiredOffset)
    return failure();

  // Re-derive indices that reach the same byte offset, but expressed in
  // terms of the hinted (consistent) element type.
  SmallVector<GEPArg> newIndices;
  if (failed(
          findIndicesForOffset(layout, *typeHint, *desiredOffset, newIndices)))
    return failure();

  // Replace the GEP by an equivalent one rooted at the type hint, keeping
  // the inbounds flag.
  rewriter.replaceOpWithNewOp<GEPOp>(
      gep, LLVM::LLVMPointerType::get(getContext()), *typeHint, gep.getBase(),
      newIndices, gep.getInbounds());

  return success();
}
342 
namespace {
/// Class abstracting over both array and struct types, turning each into ranges
/// of their sub-types.
class DestructurableTypeRange
    : public llvm::indexed_accessor_range<DestructurableTypeRange,
                                          DestructurableTypeInterface, Type,
                                          Type *, Type> {

  using Base = llvm::indexed_accessor_range<
      DestructurableTypeRange, DestructurableTypeInterface, Type, Type *, Type>;

public:
  using Base::Base;

  /// Constructs a DestructurableTypeRange from either a LLVMStructType or
  /// LLVMArrayType. The range spans all fields of a struct, respectively all
  /// elements of an array; any other type is a programming error.
  explicit DestructurableTypeRange(DestructurableTypeInterface base)
      : Base(base, 0, [&]() -> ptrdiff_t {
          return TypeSwitch<DestructurableTypeInterface, ptrdiff_t>(base)
              .Case([](LLVMStructType structType) {
                return structType.getBody().size();
              })
              .Case([](LLVMArrayType arrayType) {
                return arrayType.getNumElements();
              })
              .Default([](auto) -> ptrdiff_t {
                llvm_unreachable(
                    "Only LLVMStructType or LLVMArrayType supported");
              });
        }()) {}

  /// Returns true if this is a range over a packed struct.
  bool isPacked() const {
    if (auto structType = dyn_cast<LLVMStructType>(getBase()))
      return structType.isPacked();
    return false;
  }

private:
  /// Resolves the subelement at `index`; required by indexed_accessor_range.
  static Type dereference(DestructurableTypeInterface base, ptrdiff_t index) {
    // i32 chosen because the implementations of ArrayType and StructType
    // specifically expect it to be 32 bit. They will fail otherwise.
    Type result = base.getTypeAtIndex(
        IntegerAttr::get(IntegerType::get(base.getContext(), 32), index));
    assert(result && "Should always succeed");
    return result;
  }

  friend Base;
};
} // namespace
394 
/// Returns the list of elements of `destructurableType` that are written to by
/// a store operation writing `storeSize` bytes at `storeOffset`.
/// `storeOffset` is required to cleanly point to an immediate element within
/// the type. If the write operation were to write to any padding, write beyond
/// the aggregate or partially write to a non-aggregate, failure is returned.
// NOTE(review): the extracted listing dropped the return-type line here; the
// index in this page names it `static FailureOr<DestructurableTypeRange>` —
// confirm against the original source.
getWrittenToFields(const DataLayout &dataLayout,
                   DestructurableTypeInterface destructurableType,
                   unsigned storeSize, unsigned storeOffset) {
  DestructurableTypeRange destructurableTypeRange(destructurableType);

  // First, drop leading fields until the field starting at `storeOffset` is
  // at the front of the range.
  unsigned currentOffset = 0;
  for (; !destructurableTypeRange.empty();
       destructurableTypeRange = destructurableTypeRange.drop_front()) {
    Type type = destructurableTypeRange.front();
    if (!destructurableTypeRange.isPacked()) {
      unsigned alignment = dataLayout.getTypeABIAlignment(type);
      currentOffset = llvm::alignTo(currentOffset, alignment);
    }

    // currentOffset is guaranteed to be equal to offset since offset is either
    // 0 or stems from a type-consistent GEP indexing into just a single
    // aggregate.
    if (currentOffset == storeOffset)
      break;

    assert(currentOffset < storeOffset &&
           "storeOffset should cleanly point into an immediate field");

    currentOffset += dataLayout.getTypeSize(type);
  }

  // Then, count how many fields the store covers from that point on.
  size_t exclusiveEnd = 0;
  for (; exclusiveEnd < destructurableTypeRange.size() && storeSize > 0;
       exclusiveEnd++) {
    if (!destructurableTypeRange.isPacked()) {
      unsigned alignment =
          dataLayout.getTypeABIAlignment(destructurableTypeRange[exclusiveEnd]);
      // No padding allowed in between fields at this point in time.
      if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
        return failure();
    }

    unsigned fieldSize =
        dataLayout.getTypeSize(destructurableTypeRange[exclusiveEnd]);
    if (fieldSize > storeSize) {
      // Partial writes into an aggregate are okay since subsequent pattern
      // applications can further split these up into writes into the
      // sub-elements.
      auto subAggregate = dyn_cast<DestructurableTypeInterface>(
          destructurableTypeRange[exclusiveEnd]);
      if (!subAggregate)
        return failure();

      // Avoid splitting redundantly by making sure the store into the
      // aggregate can actually be split.
      if (failed(getWrittenToFields(dataLayout, subAggregate, storeSize,
                                    /*storeOffset=*/0)))
        return failure();

      return destructurableTypeRange.take_front(exclusiveEnd + 1);
    }
    currentOffset += fieldSize;
    storeSize -= fieldSize;
  }

  // If the storeSize is not 0 at this point we are writing past the aggregate
  // as a whole. Abort.
  if (storeSize > 0)
    return failure();
  return destructurableTypeRange.take_front(exclusiveEnd);
}
467 
/// Splits a store of the vector `value` into `address` at `storeOffset` into
/// multiple stores of each element with the goal of each generated store
/// becoming type-consistent through subsequent pattern applications.
// NOTE(review): the extracted listing dropped a parameter line here; per the
// index the signature takes `TypedValue<VectorType> value` between `address`
// and `storeOffset` — confirm against the original source.
static void splitVectorStore(const DataLayout &dataLayout, Location loc,
                             RewriterBase &rewriter, Value address,
                             unsigned storeOffset) {
  VectorType vectorType = value.getType();
  // Per-element byte size, used to compute each element's store offset.
  unsigned elementSize = dataLayout.getTypeSize(vectorType.getElementType());

  // Extract every element in the vector and store it in the given address.
  for (size_t index : llvm::seq<size_t>(0, vectorType.getNumElements())) {
    auto pos =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(index));
    auto extractOp = rewriter.create<ExtractElementOp>(loc, value, pos);

    // For convenience, we do indexing by calculating the final byte offset.
    // Other patterns will turn this into a type-consistent GEP.
    auto gepOp = rewriter.create<GEPOp>(
        loc, address.getType(), rewriter.getI8Type(), address,
        ArrayRef<GEPArg>{storeOffset + index * elementSize});

    rewriter.create<StoreOp>(loc, extractOp, gepOp);
  }
}
493 
494 /// Splits a store of the integer `value` into `address` at `storeOffset` into
495 /// multiple stores to each 'writtenToFields', making each store operation
496 /// type-consistent.
497 static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
498  RewriterBase &rewriter, Value address,
499  Value value, unsigned storeSize,
500  unsigned storeOffset,
501  DestructurableTypeRange writtenToFields) {
502  unsigned currentOffset = storeOffset;
503  for (Type type : writtenToFields) {
504  unsigned fieldSize = dataLayout.getTypeSize(type);
505 
506  // Extract the data out of the integer by first shifting right and then
507  // truncating it.
508  auto pos = rewriter.create<ConstantOp>(
509  loc, rewriter.getIntegerAttr(value.getType(),
510  (currentOffset - storeOffset) * 8));
511 
512  auto shrOp = rewriter.create<LShrOp>(loc, value, pos);
513 
514  // If we are doing a partial write into a direct field the remaining
515  // `storeSize` will be less than the size of the field. We have to truncate
516  // to the `storeSize` to avoid creating a store that wasn't in the original
517  // code.
518  IntegerType fieldIntType =
519  rewriter.getIntegerType(std::min(fieldSize, storeSize) * 8);
520  Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
521 
522  // We create an `i8` indexed GEP here as that is the easiest (offset is
523  // already known). Other patterns turn this into a type-consistent GEP.
524  auto gepOp =
525  rewriter.create<GEPOp>(loc, address.getType(), rewriter.getI8Type(),
526  address, ArrayRef<GEPArg>{currentOffset});
527  rewriter.create<StoreOp>(loc, valueToStore, gepOp);
528 
529  // No need to care about padding here since we already checked previously
530  // that no padding exists in this range.
531  currentOffset += fieldSize;
532  storeSize -= fieldSize;
533  }
534 }
535 
LogicalResult SplitStores::matchAndRewrite(StoreOp store,
                                           PatternRewriter &rewriter) const {
  Type sourceType = store.getValue().getType();
  if (!isa<IntegerType, VectorType>(sourceType)) {
    // We currently only support integer and vector sources.
    return failure();
  }

  // Only fire when the address' pointee type hint disagrees with the stored
  // type.
  Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
  if (!typeHint) {
    // Nothing to do, since it is already consistent.
    return failure();
  }

  auto dataLayout = DataLayout::closest(store);

  unsigned storeSize = dataLayout.getTypeSize(sourceType);
  unsigned offset = 0;
  Value address = store.getAddr();
  if (auto gepOp = address.getDefiningOp<GEPOp>()) {
    // Currently only handle canonical GEPs with exactly two indices,
    // indexing a single aggregate deep.
    // If the GEP is not canonical we have to fail, otherwise we would not
    // create type-consistent IR.
    // NOTE(review): the extracted listing dropped the second operand of this
    // `||` — confirm the full condition against the original source.
    if (gepOp.getIndices().size() != 2 ||
      return failure();

    // If the size of the element indexed by the GEP is smaller than the store
    // size, it is pointing into the middle of an aggregate with the store
    // storing into multiple adjacent elements. Destructure into the base
    // address of the aggregate with a store offset.
    if (storeSize > dataLayout.getTypeSize(gepOp.getResultPtrElementType())) {
      std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
      if (!byteOffset)
        return failure();

      offset = *byteOffset;
      typeHint = gepOp.getElemType();
      address = gepOp.getBase();
    }
  }

  // The hinted type must be an aggregate whose fields can be enumerated.
  auto destructurableType = typeHint.dyn_cast<DestructurableTypeInterface>();
  if (!destructurableType)
    return failure();

  FailureOr<DestructurableTypeRange> writtenToElements =
      getWrittenToFields(dataLayout, destructurableType, storeSize, offset);
  if (failed(writtenToElements))
    return failure();

  if (writtenToElements->size() <= 1) {
    // Other patterns should take care of this case, we are only interested in
    // splitting element stores.
    return failure();
  }

  if (isa<IntegerType>(sourceType)) {
    splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
                      store.getValue(), storeSize, offset, *writtenToElements);
    rewriter.eraseOp(store);
    return success();
  }

  // Add a reasonable bound to not split very large vectors that would end up
  // generating lots of code.
  if (dataLayout.getTypeSizeInBits(sourceType) > maxVectorSplitSize)
    return failure();

  // Vector types are simply split into its elements and new stores generated
  // with those. Subsequent pattern applications will split these stores further
  // if required.
  splitVectorStore(dataLayout, store.getLoc(), rewriter, address,
                   cast<TypedValue<VectorType>>(store.getValue()), offset);
  rewriter.eraseOp(store);
  return success();
}
614 
615 LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
616  PatternRewriter &rewriter) const {
617  Type sourceType = store.getValue().getType();
618  Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
619  if (!typeHint) {
620  // Nothing to do, since it is already consistent.
621  return failure();
622  }
623 
624  auto dataLayout = DataLayout::closest(store);
625  if (!areBitcastCompatible(dataLayout, typeHint, sourceType))
626  return failure();
627 
628  auto bitcastOp =
629  rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
630  rewriter.updateRootInPlace(
631  store, [&] { store.getValueMutable().assign(bitcastOp); });
632  return success();
633 }
634 
LogicalResult SplitGEP::matchAndRewrite(GEPOp gepOp,
                                        PatternRewriter &rewriter) const {
  // NOTE(review): the extracted listing dropped the initialization of
  // `typeHint` here; it is presumably `getRequiredConsistentGEPType(gepOp)`
  // as in CanonicalizeAlignedGep — confirm against the original source.
  if (succeeded(typeHint) || gepOp.getIndices().size() <= 2) {
    // GEP is not canonical or a single aggregate deep, nothing to do here.
    return failure();
  }

  // Converts a GEP index (constant attribute or dynamic value) to a GEPArg.
  // NOTE(review): the extracted listing dropped the lambda's parameter list
  // here — confirm against the original source.
  auto indexToGEPArg =
    if (auto integerAttr = dyn_cast<IntegerAttr>(index))
      return integerAttr.getValue().getSExtValue();
    return cast<Value>(index);
  };

  GEPIndicesAdaptor<ValueRange> indices = gepOp.getIndices();

  // Boundary between the leading (two-index) GEP and the remainder.
  auto splitIter = std::next(indices.begin(), 2);

  // Split off the first GEP using the first two indices.
  auto subGepOp = rewriter.create<GEPOp>(
      gepOp.getLoc(), gepOp.getType(), gepOp.getElemType(), gepOp.getBase(),
      llvm::map_to_vector(llvm::make_range(indices.begin(), splitIter),
                          indexToGEPArg),
      gepOp.getInbounds());

  // The second GEP indexes on the result pointer element type of the previous
  // with all the remaining indices and a zero upfront. If this GEP has more
  // than two indices remaining it'll be further split in subsequent pattern
  // applications.
  SmallVector<GEPArg> newIndices = {0};
  llvm::transform(llvm::make_range(splitIter, indices.end()),
                  std::back_inserter(newIndices), indexToGEPArg);
  rewriter.replaceOpWithNewOp<GEPOp>(gepOp, gepOp.getType(),
                                     subGepOp.getResultPtrElementType(),
                                     subGepOp, newIndices, gepOp.getInbounds());
  return success();
}
673 
674 //===----------------------------------------------------------------------===//
675 // Type consistency pass
676 //===----------------------------------------------------------------------===//
677 
678 namespace {
679 struct LLVMTypeConsistencyPass
680  : public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
681  void runOnOperation() override {
682  RewritePatternSet rewritePatterns(&getContext());
683  rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext());
684  rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
685  &getContext());
686  rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
687  rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
688  rewritePatterns.add<BitcastStores>(&getContext());
689  rewritePatterns.add<SplitGEP>(&getContext());
690  FrozenRewritePatternSet frozen(std::move(rewritePatterns));
691 
692  if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
693  signalPassFailure();
694  }
695 };
696 } // namespace
697 
698 std::unique_ptr<Pass> LLVM::createTypeConsistencyPass() {
699  return std::make_unique<LLVMTypeConsistencyPass>();
700 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset, SmallVectorImpl< GEPArg > &equivalentIndicesOut)
Fills in equivalentIndicesOut with GEP indices that would be equivalent to offsetting a pointer by offset bytes, assuming the GEP has base as base type.
static Type isElementTypeInconsistent(Value addr, Type expectedType)
Checks that a pointer value has a pointee type hint consistent with the expected type.
static void splitIntegerStore(const DataLayout &dataLayout, Location loc, RewriterBase &rewriter, Value address, Value value, unsigned storeSize, unsigned storeOffset, DestructurableTypeRange writtenToFields)
Splits a store of the integer value into address at storeOffset into multiple stores to each 'writtenToFields', making each store operation type-consistent.
static std::optional< uint64_t > gepToByteOffset(DataLayout &layout, GEPOp gep)
Returns the amount of bytes the provided GEP elements will offset the pointer by.
static FailureOr< Type > getRequiredConsistentGEPType(GEPOp gep)
Returns the consistent type for the GEP if the GEP is not type-consistent.
static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter, Type elemType)
Extracts a pointer to the first field of an elemType from the address pointer of the provided MemOp,...
static Type getFirstSubelementType(Type type)
Gets the type of the first subelement of type if type is destructurable, nullptr otherwise.
static FailureOr< DestructurableTypeRange > getWrittenToFields(const DataLayout &dataLayout, DestructurableTypeInterface destructurableType, unsigned storeSize, unsigned storeOffset)
Returns the list of elements of destructurableType that are written to by a store operation writing storeSize bytes at storeOffset.
static void splitVectorStore(const DataLayout &dataLayout, Location loc, RewriterBase &rewriter, Value address, TypedValue< VectorType > value, unsigned storeOffset)
Splits a store of the vector value into address at storeOffset into multiple stores of each element w...
static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs)
Checks that two types are the same or can be bitcast into one another.
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
IntegerType getI8Type()
Definition: Builders.cpp:79
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a frozen set of patterns that can be processed by a pattern applicator.
Transforms uses of pointers to a whole struct to uses of pointers to the first element of a struct.
LogicalResult matchAndRewrite(User user, PatternRewriter &rewriter) const override
Transforms type-inconsistent stores, aka stores where the type hint of the address contradicts the va...
Canonicalizes GEPs of which the base type and the pointer's type hint do not match.
LogicalResult matchAndRewrite(GEPOp gep, PatternRewriter &rewriter) const override
Class used for building a 'llvm.getelementptr'.
Definition: LLVMDialect.h:75
Class used for convenient access and iteration over GEP indices.
Definition: LLVMDialect.h:115
iterator begin() const
Returns the begin iterator, iterating over all GEP indices.
Definition: LLVMDialect.h:192
std::conditional_t< std::is_base_of< Attribute, llvm::detail::ValueOfRange< DynamicRange > >::value, Attribute, PointerUnion< IntegerAttr, llvm::detail::ValueOfRange< DynamicRange > >> value_type
Return type of 'operator[]' and the iterators 'operator*'.
Definition: LLVMDialect.h:126
iterator end() const
Returns the end iterator, iterating over all GEP indices.
Definition: LLVMDialect.h:198
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:109
bool isPacked() const
Checks if a struct is packed.
Definition: LLVMTypes.cpp:465
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Definition: LLVMTypes.cpp:473
Splits GEPs with more than two indices into multiple GEPs with exactly two indices.
Splits stores which write into multiple adjacent elements of an aggregate through a pointer.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:133
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:646
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
U dyn_cast() const
Definition: Types.h:329
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
std::unique_ptr< Pass > createTypeConsistencyPass()
Creates a pass that adjusts operations operating on pointers so they interpret pointee types as consi...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:494
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26