MLIR  19.0.0git
LLVMMemorySlot.cpp
Go to the documentation of this file.
1 //===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MemorySlot-related interfaces for LLVM dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define DEBUG_TYPE "sroa"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Interfaces for AllocaOp
29 //===----------------------------------------------------------------------===//
30 
31 llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
32  if (!getOperation()->getBlock()->isEntryBlock())
33  return {};
34 
35  return {MemorySlot{getResult(), getElemType()}};
36 }
37 
38 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
39  RewriterBase &rewriter) {
40  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
41 }
42 
43 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
44  BlockArgument argument,
45  RewriterBase &rewriter) {
46  for (Operation *user : getOperation()->getUsers())
47  if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
48  rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
49  declareOp.getVarInfo(),
50  declareOp.getLocationExpr());
51 }
52 
53 void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
54  Value defaultValue,
55  RewriterBase &rewriter) {
56  if (defaultValue && defaultValue.use_empty())
57  rewriter.eraseOp(defaultValue.getDefiningOp());
58  rewriter.eraseOp(*this);
59 }
60 
61 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
62  if (!mlir::matchPattern(getArraySize(), m_One()))
63  return {};
64 
65  auto destructurable = dyn_cast<DestructurableTypeInterface>(getElemType());
66  if (!destructurable)
67  return {};
68 
69  std::optional<DenseMap<Attribute, Type>> destructuredType =
70  destructurable.getSubelementIndexMap();
71  if (!destructuredType)
72  return {};
73 
74  return {DestructurableMemorySlot{{getResult(), getElemType()},
75  *destructuredType}};
76 }
77 
79 LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
80  const SmallPtrSetImpl<Attribute> &usedIndices,
81  RewriterBase &rewriter) {
82  assert(slot.ptr == getResult());
83  rewriter.setInsertionPointAfter(*this);
84 
85  auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
87  for (Attribute index : usedIndices) {
88  Type elemType = destructurableType.getTypeAtIndex(index);
89  assert(elemType && "used index must exist");
90  auto subAlloca = rewriter.create<LLVM::AllocaOp>(
91  getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
92  getArraySize());
93  slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
94  }
95 
96  return slotMap;
97 }
98 
99 void LLVM::AllocaOp::handleDestructuringComplete(
100  const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
101  assert(slot.ptr == getResult());
102  rewriter.eraseOp(*this);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Interfaces for LoadOp/StoreOp
107 //===----------------------------------------------------------------------===//
108 
109 bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
110  return getAddr() == slot.ptr;
111 }
112 
113 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
114 
115 Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
116  Value reachingDef, const DataLayout &dataLayout) {
117  llvm_unreachable("getStored should not be called on LoadOp");
118 }
119 
120 bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
121 
122 bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
123  return getAddr() == slot.ptr;
124 }
125 
126 /// Checks if `type` can be used in any kind of conversion sequences.
128  // Aggregate types are not bitcastable.
129  if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
130  return false;
131 
132  // LLVM vector types are only used for either pointers or target specific
133  // types. These types cannot be casted in the general case, thus the memory
134  // optimizations do not support them.
135  if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
136  return false;
137 
138  // Scalable types are not supported.
139  if (auto vectorType = dyn_cast<VectorType>(type))
140  return !vectorType.isScalable();
141  return true;
142 }
143 
144 /// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
145 /// truncations. Checks for narrowing or widening conversion compatibility
146 /// depending on `narrowingConversion`.
147 static bool areConversionCompatible(const DataLayout &layout, Type targetType,
148  Type srcType, bool narrowingConversion) {
149  if (targetType == srcType)
150  return true;
151 
152  if (!isSupportedTypeForConversion(targetType) ||
154  return false;
155 
156  uint64_t targetSize = layout.getTypeSize(targetType);
157  uint64_t srcSize = layout.getTypeSize(srcType);
158 
159  // Pointer casts will only be sane when the bitsize of both pointer types is
160  // the same.
161  if (isa<LLVM::LLVMPointerType>(targetType) &&
162  isa<LLVM::LLVMPointerType>(srcType))
163  return targetSize == srcSize;
164 
165  if (narrowingConversion)
166  return targetSize <= srcSize;
167  return targetSize >= srcSize;
168 }
169 
170 /// Checks if `dataLayout` describes a little endian layout.
171 static bool isBigEndian(const DataLayout &dataLayout) {
172  auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
173  return endiannessStr && endiannessStr == "big";
174 }
175 
176 /// Converts a value to an integer type of the same size.
177 /// Assumes that the type can be converted.
179  const DataLayout &dataLayout) {
180  Type type = val.getType();
181  assert(isSupportedTypeForConversion(type) &&
182  "expected value to have a convertible type");
183 
184  if (isa<IntegerType>(type))
185  return val;
186 
187  uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
188  IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
189 
190  if (isa<LLVM::LLVMPointerType>(type))
191  return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
192  return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
193 }
194 
195 /// Converts a value with an integer type to `targetType`.
197  Value val, Type targetType) {
198  assert(isa<IntegerType>(val.getType()) &&
199  "expected value to have an integer type");
200  assert(isSupportedTypeForConversion(targetType) &&
201  "expected the target type to be supported for conversions");
202  if (val.getType() == targetType)
203  return val;
204  if (isa<LLVM::LLVMPointerType>(targetType))
205  return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
206  return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
207 }
208 
209 /// Constructs operations that convert `srcValue` into a new value of type
210 /// `targetType`. Assumes the types have the same bitsize.
212  Value srcValue, Type targetType,
213  const DataLayout &dataLayout) {
214  Type srcType = srcValue.getType();
215  assert(areConversionCompatible(dataLayout, targetType, srcType,
216  /*narrowingConversion=*/true) &&
217  "expected that the compatibility was checked before");
218 
219  // Nothing has to be done if the types are already the same.
220  if (srcType == targetType)
221  return srcValue;
222 
223  // In the special case of casting one pointer to another, we want to generate
224  // an address space cast. Bitcasts of pointers are not allowed and using
225  // pointer to integer conversions are not equivalent due to the loss of
226  // provenance.
227  if (isa<LLVM::LLVMPointerType>(targetType) &&
228  isa<LLVM::LLVMPointerType>(srcType))
229  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
230  srcValue);
231 
232  // For all other castable types, casting through integers is necessary.
233  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
234  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
235 }
236 
237 /// Constructs operations that convert `srcValue` into a new value of type
238 /// `targetType`. Performs bit-level extraction if the source type is larger
239 /// than the target type. Assumes that this conversion is possible.
241  Value srcValue, Type targetType,
242  const DataLayout &dataLayout) {
243  // Get the types of the source and target values.
244  Type srcType = srcValue.getType();
245  assert(areConversionCompatible(dataLayout, targetType, srcType,
246  /*narrowingConversion=*/true) &&
247  "expected that the compatibility was checked before");
248 
249  uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
250  uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
251  if (srcTypeSize == targetTypeSize)
252  return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
253 
254  // First, cast the value to a same-sized integer type.
255  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
256 
257  // Truncate the integer if the size of the target is less than the value.
258  if (isBigEndian(dataLayout)) {
259  uint64_t shiftAmount = srcTypeSize - targetTypeSize;
260  auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
261  loc, rewriter.getIntegerAttr(srcType, shiftAmount));
262  replacement =
263  rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
264  }
265 
266  replacement = rewriter.create<LLVM::TruncOp>(
267  loc, rewriter.getIntegerType(targetTypeSize), replacement);
268 
269  // Now cast the integer to the actual target type if required.
270  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
271 }
272 
273 /// Constructs operations that insert the bits of `srcValue` into the
274 /// "beginning" of `reachingDef` (beginning is endianness dependent).
275 /// Assumes that this conversion is possible.
277  Value srcValue, Value reachingDef,
278  const DataLayout &dataLayout) {
279 
280  assert(areConversionCompatible(dataLayout, reachingDef.getType(),
281  srcValue.getType(),
282  /*narrowingConversion=*/false) &&
283  "expected that the compatibility was checked before");
284  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
285  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
286  if (slotTypeSize == valueTypeSize)
287  return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
288  dataLayout);
289 
290  // In the case where the store only overwrites parts of the memory,
291  // bit fiddling is required to construct the new value.
292 
293  // First convert both values to integers of the same size.
294  Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
295  Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
296  // Extend the value to the size of the reaching definition.
297  valueAsInt =
298  rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
299  uint64_t sizeDifference = slotTypeSize - valueTypeSize;
300  if (isBigEndian(dataLayout)) {
301  // On big endian systems, a store to the base pointer overwrites the most
302  // significant bits. To accomodate for this, the stored value needs to be
303  // shifted into the according position.
304  Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
305  loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
306  valueAsInt =
307  rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
308  }
309 
310  // Construct the mask that is used to erase the bits that are overwritten by
311  // the store.
312  APInt maskValue;
313  if (isBigEndian(dataLayout)) {
314  // Build a mask that has the most significant bits set to zero.
315  // Note: This is the same as 2^sizeDifference - 1
316  maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
317  } else {
318  // Build a mask that has the least significant bits set to zero.
319  // Note: This is the same as -(2^valueTypeSize)
320  maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
321  maskValue.flipAllBits();
322  }
323 
324  // Mask out the affected bits ...
325  Value mask = rewriter.create<LLVM::ConstantOp>(
326  loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
327  Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
328 
329  // ... and combine the result with the new value.
330  Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
331 
332  return castIntValueToSameSizedType(rewriter, loc, combined,
333  reachingDef.getType());
334 }
335 
336 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
337  Value reachingDef,
338  const DataLayout &dataLayout) {
339  assert(reachingDef && reachingDef.getType() == slot.elemType &&
340  "expected the reaching definition's type to match the slot's type");
341  return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef,
342  dataLayout);
343 }
344 
345 bool LLVM::LoadOp::canUsesBeRemoved(
346  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
347  SmallVectorImpl<OpOperand *> &newBlockingUses,
348  const DataLayout &dataLayout) {
349  if (blockingUses.size() != 1)
350  return false;
351  Value blockingUse = (*blockingUses.begin())->get();
352  // If the blocking use is the slot ptr itself, there will be enough
353  // context to reconstruct the result of the load at removal time, so it can
354  // be removed (provided it is not volatile).
355  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
356  areConversionCompatible(dataLayout, getResult().getType(),
357  slot.elemType, /*narrowingConversion=*/true) &&
358  !getVolatile_();
359 }
360 
361 DeletionKind LLVM::LoadOp::removeBlockingUses(
362  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
363  RewriterBase &rewriter, Value reachingDefinition,
364  const DataLayout &dataLayout) {
365  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
366  // pointer.
367  Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
368  getResult().getType(), dataLayout);
369  rewriter.replaceAllUsesWith(getResult(), newResult);
370  return DeletionKind::Delete;
371 }
372 
373 bool LLVM::StoreOp::canUsesBeRemoved(
374  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
375  SmallVectorImpl<OpOperand *> &newBlockingUses,
376  const DataLayout &dataLayout) {
377  if (blockingUses.size() != 1)
378  return false;
379  Value blockingUse = (*blockingUses.begin())->get();
380  // If the blocking use is the slot ptr itself, dropping the store is
381  // fine, provided we are currently promoting its target value. Don't allow a
382  // store OF the slot pointer, only INTO the slot pointer.
383  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
384  getValue() != slot.ptr &&
385  areConversionCompatible(dataLayout, slot.elemType,
386  getValue().getType(),
387  /*narrowingConversion=*/false) &&
388  !getVolatile_();
389 }
390 
391 DeletionKind LLVM::StoreOp::removeBlockingUses(
392  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
393  RewriterBase &rewriter, Value reachingDefinition,
394  const DataLayout &dataLayout) {
395  return DeletionKind::Delete;
396 }
397 
398 /// Checks if `slot` can be accessed through the provided access type.
399 static bool isValidAccessType(const MemorySlot &slot, Type accessType,
400  const DataLayout &dataLayout) {
401  return dataLayout.getTypeSize(accessType) <=
402  dataLayout.getTypeSize(slot.elemType);
403 }
404 
405 LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
406  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
407  const DataLayout &dataLayout) {
408  return success(getAddr() != slot.ptr ||
409  isValidAccessType(slot, getType(), dataLayout));
410 }
411 
412 LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
413  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
414  const DataLayout &dataLayout) {
415  return success(getAddr() != slot.ptr ||
416  isValidAccessType(slot, getValue().getType(), dataLayout));
417 }
418 
419 /// Returns the subslot's type at the requested index.
421  Attribute index) {
422  auto subelementIndexMap =
423  cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
424  if (!subelementIndexMap)
425  return {};
426  assert(!subelementIndexMap->empty());
427 
428  // Note: Returns a null-type when no entry was found.
429  return subelementIndexMap->lookup(index);
430 }
431 
432 bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
433  SmallPtrSetImpl<Attribute> &usedIndices,
434  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
435  const DataLayout &dataLayout) {
436  if (getVolatile_())
437  return false;
438 
439  // A load always accesses the first element of the destructured slot.
440  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
441  Type subslotType = getTypeAtIndex(slot, index);
442  if (!subslotType)
443  return false;
444 
445  // The access can only be replaced when the subslot is read within its bounds.
446  if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
447  return false;
448 
449  usedIndices.insert(index);
450  return true;
451 }
452 
453 DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
455  RewriterBase &rewriter,
456  const DataLayout &dataLayout) {
457  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
458  auto it = subslots.find(index);
459  assert(it != subslots.end());
460 
461  rewriter.modifyOpInPlace(
462  *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
463  return DeletionKind::Keep;
464 }
465 
466 bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
467  SmallPtrSetImpl<Attribute> &usedIndices,
468  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
469  const DataLayout &dataLayout) {
470  if (getVolatile_())
471  return false;
472 
473  // Storing the pointer to memory cannot be dealt with.
474  if (getValue() == slot.ptr)
475  return false;
476 
477  // A store always accesses the first element of the destructured slot.
478  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
479  Type subslotType = getTypeAtIndex(slot, index);
480  if (!subslotType)
481  return false;
482 
483  // The access can only be replaced when the subslot is read within its bounds.
484  if (dataLayout.getTypeSize(getValue().getType()) >
485  dataLayout.getTypeSize(subslotType))
486  return false;
487 
488  usedIndices.insert(index);
489  return true;
490 }
491 
492 DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
494  RewriterBase &rewriter,
495  const DataLayout &dataLayout) {
496  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
497  auto it = subslots.find(index);
498  assert(it != subslots.end());
499 
500  rewriter.modifyOpInPlace(
501  *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
502  return DeletionKind::Keep;
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // Interfaces for discardable OPs
507 //===----------------------------------------------------------------------===//
508 
509 /// Conditions the deletion of the operation to the removal of all its uses.
510 static bool forwardToUsers(Operation *op,
511  SmallVectorImpl<OpOperand *> &newBlockingUses) {
512  for (Value result : op->getResults())
513  for (OpOperand &use : result.getUses())
514  newBlockingUses.push_back(&use);
515  return true;
516 }
517 
518 bool LLVM::BitcastOp::canUsesBeRemoved(
519  const SmallPtrSetImpl<OpOperand *> &blockingUses,
520  SmallVectorImpl<OpOperand *> &newBlockingUses,
521  const DataLayout &dataLayout) {
522  return forwardToUsers(*this, newBlockingUses);
523 }
524 
525 DeletionKind LLVM::BitcastOp::removeBlockingUses(
526  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
527  return DeletionKind::Delete;
528 }
529 
530 bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
531  const SmallPtrSetImpl<OpOperand *> &blockingUses,
532  SmallVectorImpl<OpOperand *> &newBlockingUses,
533  const DataLayout &dataLayout) {
534  return forwardToUsers(*this, newBlockingUses);
535 }
536 
537 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
538  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
539  return DeletionKind::Delete;
540 }
541 
542 bool LLVM::LifetimeStartOp::canUsesBeRemoved(
543  const SmallPtrSetImpl<OpOperand *> &blockingUses,
544  SmallVectorImpl<OpOperand *> &newBlockingUses,
545  const DataLayout &dataLayout) {
546  return true;
547 }
548 
549 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
550  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
551  return DeletionKind::Delete;
552 }
553 
554 bool LLVM::LifetimeEndOp::canUsesBeRemoved(
555  const SmallPtrSetImpl<OpOperand *> &blockingUses,
556  SmallVectorImpl<OpOperand *> &newBlockingUses,
557  const DataLayout &dataLayout) {
558  return true;
559 }
560 
561 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
562  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
563  return DeletionKind::Delete;
564 }
565 
566 bool LLVM::InvariantStartOp::canUsesBeRemoved(
567  const SmallPtrSetImpl<OpOperand *> &blockingUses,
568  SmallVectorImpl<OpOperand *> &newBlockingUses,
569  const DataLayout &dataLayout) {
570  return true;
571 }
572 
573 DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
574  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
575  return DeletionKind::Delete;
576 }
577 
578 bool LLVM::InvariantEndOp::canUsesBeRemoved(
579  const SmallPtrSetImpl<OpOperand *> &blockingUses,
580  SmallVectorImpl<OpOperand *> &newBlockingUses,
581  const DataLayout &dataLayout) {
582  return true;
583 }
584 
585 DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
586  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
587  return DeletionKind::Delete;
588 }
589 
590 bool LLVM::DbgDeclareOp::canUsesBeRemoved(
591  const SmallPtrSetImpl<OpOperand *> &blockingUses,
592  SmallVectorImpl<OpOperand *> &newBlockingUses,
593  const DataLayout &dataLayout) {
594  return true;
595 }
596 
597 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
598  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
599  return DeletionKind::Delete;
600 }
601 
602 bool LLVM::DbgValueOp::canUsesBeRemoved(
603  const SmallPtrSetImpl<OpOperand *> &blockingUses,
604  SmallVectorImpl<OpOperand *> &newBlockingUses,
605  const DataLayout &dataLayout) {
606  // There is only one operand that we can remove the use of.
607  if (blockingUses.size() != 1)
608  return false;
609 
610  return (*blockingUses.begin())->get() == getValue();
611 }
612 
613 DeletionKind LLVM::DbgValueOp::removeBlockingUses(
614  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
615  // Rewriter by default is after '*this', but we need it before '*this'.
616  rewriter.setInsertionPoint(*this);
617 
618  // Rather than dropping the debug value, replace it with undef to preserve the
619  // debug local variable info. This allows the debugger to inform the user that
620  // the variable has been optimized out.
621  auto undef =
622  rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
623  rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); });
624  return DeletionKind::Keep;
625 }
626 
627 bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
628 
629 void LLVM::DbgDeclareOp::visitReplacedValues(
630  ArrayRef<std::pair<Operation *, Value>> definitions,
631  RewriterBase &rewriter) {
632  for (auto [op, value] : definitions) {
633  rewriter.setInsertionPointAfter(op);
634  rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
635  getLocationExpr());
636  }
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // Interfaces for GEPOp
641 //===----------------------------------------------------------------------===//
642 
643 static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
644  return llvm::all_of(gepOp.getIndices(), [](auto index) {
645  auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
646  return indexAttr && indexAttr.getValue() == 0;
647  });
648 }
649 
650 bool LLVM::GEPOp::canUsesBeRemoved(
651  const SmallPtrSetImpl<OpOperand *> &blockingUses,
652  SmallVectorImpl<OpOperand *> &newBlockingUses,
653  const DataLayout &dataLayout) {
654  // GEP can be removed as long as it is a no-op and its users can be removed.
655  if (!hasAllZeroIndices(*this))
656  return false;
657  return forwardToUsers(*this, newBlockingUses);
658 }
659 
660 DeletionKind LLVM::GEPOp::removeBlockingUses(
661  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
662  return DeletionKind::Delete;
663 }
664 
665 /// Returns the amount of bytes the provided GEP elements will offset the
666 /// pointer by. Returns nullopt if no constant offset could be computed.
667 static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
668  LLVM::GEPOp gep) {
669  // Collects all indices.
670  SmallVector<uint64_t> indices;
671  for (auto index : gep.getIndices()) {
672  auto constIndex = dyn_cast<IntegerAttr>(index);
673  if (!constIndex)
674  return {};
675  int64_t gepIndex = constIndex.getInt();
676  // Negative indices are not supported.
677  if (gepIndex < 0)
678  return {};
679  indices.push_back(gepIndex);
680  }
681 
682  Type currentType = gep.getElemType();
683  uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);
684 
685  for (uint64_t index : llvm::drop_begin(indices)) {
686  bool shouldCancel =
687  TypeSwitch<Type, bool>(currentType)
688  .Case([&](LLVM::LLVMArrayType arrayType) {
689  offset +=
690  index * dataLayout.getTypeSize(arrayType.getElementType());
691  currentType = arrayType.getElementType();
692  return false;
693  })
694  .Case([&](LLVM::LLVMStructType structType) {
695  ArrayRef<Type> body = structType.getBody();
696  assert(index < body.size() && "expected valid struct indexing");
697  for (uint32_t i : llvm::seq(index)) {
698  if (!structType.isPacked())
699  offset = llvm::alignTo(
700  offset, dataLayout.getTypeABIAlignment(body[i]));
701  offset += dataLayout.getTypeSize(body[i]);
702  }
703 
704  // Align for the current type as well.
705  if (!structType.isPacked())
706  offset = llvm::alignTo(
707  offset, dataLayout.getTypeABIAlignment(body[index]));
708  currentType = body[index];
709  return false;
710  })
711  .Default([&](Type type) {
712  LLVM_DEBUG(llvm::dbgs()
713  << "[sroa] Unsupported type for offset computations"
714  << type << "\n");
715  return true;
716  });
717 
718  if (shouldCancel)
719  return std::nullopt;
720  }
721 
722  return offset;
723 }
724 
namespace {
/// Describes where an access lands inside a destructurable slot: which
/// element of the aggregate it hits and how far into that element it starts.
struct SubslotAccessInfo {
  /// Index of the parent slot's element that the access falls into.
  uint32_t index;
  /// Byte offset of the access relative to the start of that element.
  uint64_t subslotOffset;
};
} // namespace
735 
736 /// Computes subslot access information for an access into `slot` with the given
737 /// offset.
738 /// Returns nullopt when the offset is out-of-bounds or when the access is into
739 /// the padding of `slot`.
740 static std::optional<SubslotAccessInfo>
742  const DataLayout &dataLayout, LLVM::GEPOp gep) {
743  std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
744  if (!offset)
745  return {};
746 
747  // Helper to check that a constant index is in the bounds of the GEP index
748  // representation. LLVM dialects's GEP arguments have a limited bitwidth, thus
749  // this additional check is necessary.
750  auto isOutOfBoundsGEPIndex = [](uint64_t index) {
751  return index >= (1 << LLVM::kGEPConstantBitWidth);
752  };
753 
754  Type type = slot.elemType;
755  if (*offset >= dataLayout.getTypeSize(type))
756  return {};
758  .Case([&](LLVM::LLVMArrayType arrayType)
759  -> std::optional<SubslotAccessInfo> {
760  // Find which element of the array contains the offset.
761  uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
762  uint64_t index = *offset / elemSize;
763  if (isOutOfBoundsGEPIndex(index))
764  return {};
765  return SubslotAccessInfo{static_cast<uint32_t>(index),
766  *offset - (index * elemSize)};
767  })
768  .Case([&](LLVM::LLVMStructType structType)
769  -> std::optional<SubslotAccessInfo> {
770  uint64_t distanceToStart = 0;
771  // Walk over the elements of the struct to find in which of
772  // them the offset is.
773  for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
774  uint64_t elemSize = dataLayout.getTypeSize(elem);
775  if (!structType.isPacked()) {
776  distanceToStart = llvm::alignTo(
777  distanceToStart, dataLayout.getTypeABIAlignment(elem));
778  // If the offset is in padding, cancel the rewrite.
779  if (offset < distanceToStart)
780  return {};
781  }
782 
783  if (offset < distanceToStart + elemSize) {
784  if (isOutOfBoundsGEPIndex(index))
785  return {};
786  // The offset is within this element, stop iterating the
787  // struct and return the index.
788  return SubslotAccessInfo{static_cast<uint32_t>(index),
789  *offset - distanceToStart};
790  }
791 
792  // The offset is not within this element, continue walking
793  // over the struct.
794  distanceToStart += elemSize;
795  }
796 
797  return {};
798  });
799 }
800 
801 /// Constructs a byte array type of the given size.
802 static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
803  unsigned size) {
804  auto byteType = IntegerType::get(context, 8);
805  return LLVM::LLVMArrayType::get(context, byteType, size);
806 }
807 
808 LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
809  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
810  const DataLayout &dataLayout) {
811  if (getBase() != slot.ptr)
812  return success();
813  std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
814  if (!gepOffset)
815  return failure();
816  uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
817  // Check that the access is strictly inside the slot.
818  if (*gepOffset >= slotSize)
819  return failure();
820  // Every access that remains in bounds of the remaining slot is considered
821  // legal.
822  mustBeSafelyUsed.emplace_back<MemorySlot>(
823  {getRes(), getByteArrayType(getContext(), slotSize - *gepOffset)});
824  return success();
825 }
826 
827 bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
828  SmallPtrSetImpl<Attribute> &usedIndices,
829  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
830  const DataLayout &dataLayout) {
831  if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
832  return false;
833 
834  if (getBase() != slot.ptr)
835  return false;
836  std::optional<SubslotAccessInfo> accessInfo =
837  getSubslotAccessInfo(slot, dataLayout, *this);
838  if (!accessInfo)
839  return false;
840  auto indexAttr =
841  IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
842  assert(slot.elementPtrs.contains(indexAttr));
843  usedIndices.insert(indexAttr);
844 
845  // The remainder of the subslot should be accesses in-bounds. Thus, we create
846  // a dummy slot with the size of the remainder.
847  Type subslotType = slot.elementPtrs.lookup(indexAttr);
848  uint64_t slotSize = dataLayout.getTypeSize(subslotType);
849  LLVM::LLVMArrayType remainingSlotType =
850  getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
851  mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});
852 
853  return true;
854 }
855 
856 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
858  RewriterBase &rewriter,
859  const DataLayout &dataLayout) {
860  std::optional<SubslotAccessInfo> accessInfo =
861  getSubslotAccessInfo(slot, dataLayout, *this);
862  assert(accessInfo && "expected access info to be checked before");
863  auto indexAttr =
864  IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
865  const MemorySlot &newSlot = subslots.at(indexAttr);
866 
867  auto byteType = IntegerType::get(rewriter.getContext(), 8);
868  auto newPtr = rewriter.createOrFold<LLVM::GEPOp>(
869  getLoc(), getResult().getType(), byteType, newSlot.ptr,
870  ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
871  rewriter.replaceAllUsesWith(getResult(), newPtr);
872  return DeletionKind::Delete;
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // Utilities for memory intrinsics
877 //===----------------------------------------------------------------------===//
878 
879 namespace {
880 
881 /// Returns the length of the given memory intrinsic in bytes if it can be known
882 /// at compile-time on a best-effort basis, nothing otherwise.
883 template <class MemIntr>
884 std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
885  APInt memIntrLen;
886  if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
887  return {};
888  if (memIntrLen.getBitWidth() > 64)
889  return {};
890  return memIntrLen.getZExtValue();
891 }
892 
893 /// Returns the length of the given memory intrinsic in bytes if it can be known
894 /// at compile-time on a best-effort basis, nothing otherwise.
895 /// Because MemcpyInlineOp has its length encoded as an attribute, this requires
896 /// specialized handling.
897 template <>
898 std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
899  APInt memIntrLen = op.getLen();
900  if (memIntrLen.getBitWidth() > 64)
901  return {};
902  return memIntrLen.getZExtValue();
903 }
904 
905 } // namespace
906 
907 /// Returns whether one can be sure the memory intrinsic does not write outside
908 /// of the bounds of the given slot, on a best-effort basis.
909 template <class MemIntr>
910 static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
911  const DataLayout &dataLayout) {
912  if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
913  op.getDst() != slot.ptr)
914  return false;
915 
916  std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op);
917  return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(slot.elemType);
918 }
919 
920 /// Checks whether all indices are i32. This is used to check GEPs can index
921 /// into them.
922 static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
923  Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
924  return llvm::all_of(llvm::make_first_range(slot.elementPtrs),
925  [&](Attribute index) {
926  auto intIndex = dyn_cast<IntegerAttr>(index);
927  return intIndex && intIndex.getType() == i32;
928  });
929 }
930 
931 //===----------------------------------------------------------------------===//
932 // Interfaces for memset
933 //===----------------------------------------------------------------------===//
934 
935 bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
936 
937 bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
938  return getDst() == slot.ptr;
939 }
940 
941 Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
942  Value reachingDef,
943  const DataLayout &dataLayout) {
944  // TODO: Support non-integer types.
945  return TypeSwitch<Type, Value>(slot.elemType)
946  .Case([&](IntegerType intType) -> Value {
947  if (intType.getWidth() == 8)
948  return getVal();
949 
950  assert(intType.getWidth() % 8 == 0);
951 
952  // Build the memset integer by repeatedly shifting the value and
953  // or-ing it with the previous value.
954  uint64_t coveredBits = 8;
955  Value currentValue =
956  rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
957  while (coveredBits < intType.getWidth()) {
958  Value shiftBy =
959  rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
960  Value shifted =
961  rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
962  currentValue =
963  rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
964  coveredBits *= 2;
965  }
966 
967  return currentValue;
968  })
969  .Default([](Type) -> Value {
970  llvm_unreachable(
971  "getStored should not be called on memset to unsupported type");
972  });
973 }
974 
975 bool LLVM::MemsetOp::canUsesBeRemoved(
976  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
977  SmallVectorImpl<OpOperand *> &newBlockingUses,
978  const DataLayout &dataLayout) {
979  // TODO: Support non-integer types.
980  bool canConvertType =
982  .Case([](IntegerType intType) {
983  return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
984  })
985  .Default([](Type) { return false; });
986  if (!canConvertType)
987  return false;
988 
989  if (getIsVolatile())
990  return false;
991 
992  return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
993 }
994 
995 DeletionKind LLVM::MemsetOp::removeBlockingUses(
996  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
997  RewriterBase &rewriter, Value reachingDefinition,
998  const DataLayout &dataLayout) {
999  return DeletionKind::Delete;
1000 }
1001 
1002 LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
1003  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1004  const DataLayout &dataLayout) {
1005  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
1006 }
1007 
1008 bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
1009  SmallPtrSetImpl<Attribute> &usedIndices,
1010  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1011  const DataLayout &dataLayout) {
1012  if (&slot.elemType.getDialect() != getOperation()->getDialect())
1013  return false;
1014 
1015  if (getIsVolatile())
1016  return false;
1017 
1018  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1019  return false;
1020 
1021  if (!areAllIndicesI32(slot))
1022  return false;
1023 
1024  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
1025 }
1026 
1027 DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
1029  RewriterBase &rewriter,
1030  const DataLayout &dataLayout) {
1031  std::optional<DenseMap<Attribute, Type>> types =
1032  cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
1033 
1034  IntegerAttr memsetLenAttr;
1035  bool successfulMatch =
1036  matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
1037  (void)successfulMatch;
1038  assert(successfulMatch);
1039 
1040  bool packed = false;
1041  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
1042  packed = structType.isPacked();
1043 
1044  Type i32 = IntegerType::get(getContext(), 32);
1045  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
1046  uint64_t covered = 0;
1047  for (size_t i = 0; i < types->size(); i++) {
1048  // Create indices on the fly to get elements in the right order.
1049  Attribute index = IntegerAttr::get(i32, i);
1050  Type elemType = types->at(index);
1051  uint64_t typeSize = dataLayout.getTypeSize(elemType);
1052 
1053  if (!packed)
1054  covered =
1055  llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
1056 
1057  if (covered >= memsetLen)
1058  break;
1059 
1060  // If this subslot is used, apply a new memset to it.
1061  // Otherwise, only compute its offset within the original memset.
1062  if (subslots.contains(index)) {
1063  uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
1064 
1065  Value newMemsetSizeValue =
1066  rewriter
1067  .create<LLVM::ConstantOp>(
1068  getLen().getLoc(),
1069  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
1070  .getResult();
1071 
1072  rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
1073  getVal(), newMemsetSizeValue,
1074  getIsVolatile());
1075  }
1076 
1077  covered += typeSize;
1078  }
1079 
1080  return DeletionKind::Delete;
1081 }
1082 
1083 //===----------------------------------------------------------------------===//
1084 // Interfaces for memcpy/memmove
1085 //===----------------------------------------------------------------------===//
1086 
1087 template <class MemcpyLike>
1088 static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
1089  return op.getSrc() == slot.ptr;
1090 }
1091 
1092 template <class MemcpyLike>
1093 static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
1094  return op.getDst() == slot.ptr;
1095 }
1096 
1097 template <class MemcpyLike>
1098 static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
1099  RewriterBase &rewriter) {
1100  return rewriter.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
1101 }
1102 
1103 template <class MemcpyLike>
1104 static bool
1105 memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
1106  const SmallPtrSetImpl<OpOperand *> &blockingUses,
1107  SmallVectorImpl<OpOperand *> &newBlockingUses,
1108  const DataLayout &dataLayout) {
1109  // If source and destination are the same, memcpy behavior is undefined and
1110  // memmove is a no-op. Because there is no memory change happening here,
1111  // simplifying such operations is left to canonicalization.
1112  if (op.getDst() == op.getSrc())
1113  return false;
1114 
1115  if (op.getIsVolatile())
1116  return false;
1117 
1118  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
1119 }
1120 
1121 template <class MemcpyLike>
1122 static DeletionKind
1123 memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
1124  const SmallPtrSetImpl<OpOperand *> &blockingUses,
1125  RewriterBase &rewriter, Value reachingDefinition) {
1126  if (op.loadsFrom(slot))
1127  rewriter.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition,
1128  op.getDst());
1129  return DeletionKind::Delete;
1130 }
1131 
1132 template <class MemcpyLike>
1133 static LogicalResult
1134 memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
1135  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
1136  DataLayout dataLayout = DataLayout::closest(op);
1137  // While rewiring memcpy-like intrinsics only supports full copies, partial
1138  // copies are still safe accesses so it is enough to only check for writes
1139  // within bounds.
1140  return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
1141 }
1142 
1143 template <class MemcpyLike>
1144 static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1145  SmallPtrSetImpl<Attribute> &usedIndices,
1146  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1147  const DataLayout &dataLayout) {
1148  if (op.getIsVolatile())
1149  return false;
1150 
1151  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1152  return false;
1153 
1154  if (!areAllIndicesI32(slot))
1155  return false;
1156 
1157  // Only full copies are supported.
1158  if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
1159  return false;
1160 
1161  if (op.getSrc() == slot.ptr)
1162  for (Attribute index : llvm::make_first_range(slot.elementPtrs))
1163  usedIndices.insert(index);
1164 
1165  return true;
1166 }
1167 
1168 namespace {
1169 
1170 template <class MemcpyLike>
1171 void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
1172  MemcpyLike toReplace, Value dst, Value src,
1173  Type toCpy, bool isVolatile) {
1174  Value memcpySize = rewriter.create<LLVM::ConstantOp>(
1175  toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
1176  layout.getTypeSize(toCpy)));
1177  rewriter.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
1178  isVolatile);
1179 }
1180 
1181 template <>
1182 void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
1183  LLVM::MemcpyInlineOp toReplace, Value dst,
1184  Value src, Type toCpy, bool isVolatile) {
1185  Type lenType = IntegerType::get(toReplace->getContext(),
1186  toReplace.getLen().getBitWidth());
1187  rewriter.create<LLVM::MemcpyInlineOp>(
1188  toReplace.getLoc(), dst, src,
1189  IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
1190 }
1191 
1192 } // namespace
1193 
1194 /// Rewires a memcpy-like operation. Only copies to or from the full slot are
1195 /// supported.
1196 template <class MemcpyLike>
1197 static DeletionKind
1198 memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1199  DenseMap<Attribute, MemorySlot> &subslots, RewriterBase &rewriter,
1200  const DataLayout &dataLayout) {
1201  if (subslots.empty())
1202  return DeletionKind::Delete;
1203 
1204  assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
1205  bool isDst = slot.ptr == op.getDst();
1206 
1207 #ifndef NDEBUG
1208  size_t slotsTreated = 0;
1209 #endif
1210 
1211  // It was previously checked that index types are consistent, so this type can
1212  // be fetched now.
1213  Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
1214  for (size_t i = 0, e = slot.elementPtrs.size(); i != e; i++) {
1215  Attribute index = IntegerAttr::get(indexType, i);
1216  if (!subslots.contains(index))
1217  continue;
1218  const MemorySlot &subslot = subslots.at(index);
1219 
1220 #ifndef NDEBUG
1221  slotsTreated++;
1222 #endif
1223 
1224  // First get a pointer to the equivalent of this subslot from the source
1225  // pointer.
1226  SmallVector<LLVM::GEPArg> gepIndices{
1227  0, static_cast<int32_t>(
1228  cast<IntegerAttr>(index).getValue().getZExtValue())};
1229  Value subslotPtrInOther = rewriter.create<LLVM::GEPOp>(
1231  isDst ? op.getSrc() : op.getDst(), gepIndices);
1232 
1233  // Then create a new memcpy out of this source pointer.
1234  createMemcpyLikeToReplace(rewriter, dataLayout, op,
1235  isDst ? subslot.ptr : subslotPtrInOther,
1236  isDst ? subslotPtrInOther : subslot.ptr,
1237  subslot.elemType, op.getIsVolatile());
1238  }
1239 
1240  assert(subslots.size() == slotsTreated);
1241 
1242  return DeletionKind::Delete;
1243 }
1244 
1245 bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
1246  return memcpyLoadsFrom(*this, slot);
1247 }
1248 
1249 bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
1250  return memcpyStoresTo(*this, slot);
1251 }
1252 
1253 Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
1254  Value reachingDef,
1255  const DataLayout &dataLayout) {
1256  return memcpyGetStored(*this, slot, rewriter);
1257 }
1258 
1259 bool LLVM::MemcpyOp::canUsesBeRemoved(
1260  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1261  SmallVectorImpl<OpOperand *> &newBlockingUses,
1262  const DataLayout &dataLayout) {
1263  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1264  dataLayout);
1265 }
1266 
1267 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
1268  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1269  RewriterBase &rewriter, Value reachingDefinition,
1270  const DataLayout &dataLayout) {
1271  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
1272  reachingDefinition);
1273 }
1274 
1275 LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
1276  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1277  const DataLayout &dataLayout) {
1278  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1279 }
1280 
1281 bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
1282  SmallPtrSetImpl<Attribute> &usedIndices,
1283  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1284  const DataLayout &dataLayout) {
1285  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1286  dataLayout);
1287 }
1288 
1289 DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
1291  RewriterBase &rewriter,
1292  const DataLayout &dataLayout) {
1293  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
1294 }
1295 
1296 bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
1297  return memcpyLoadsFrom(*this, slot);
1298 }
1299 
1300 bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
1301  return memcpyStoresTo(*this, slot);
1302 }
1303 
1304 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
1305  RewriterBase &rewriter, Value reachingDef,
1306  const DataLayout &dataLayout) {
1307  return memcpyGetStored(*this, slot, rewriter);
1308 }
1309 
1310 bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
1311  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1312  SmallVectorImpl<OpOperand *> &newBlockingUses,
1313  const DataLayout &dataLayout) {
1314  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1315  dataLayout);
1316 }
1317 
1318 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
1319  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1320  RewriterBase &rewriter, Value reachingDefinition,
1321  const DataLayout &dataLayout) {
1322  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
1323  reachingDefinition);
1324 }
1325 
1326 LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
1327  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1328  const DataLayout &dataLayout) {
1329  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1330 }
1331 
1332 bool LLVM::MemcpyInlineOp::canRewire(
1333  const DestructurableMemorySlot &slot,
1334  SmallPtrSetImpl<Attribute> &usedIndices,
1335  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1336  const DataLayout &dataLayout) {
1337  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1338  dataLayout);
1339 }
1340 
1342 LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
1344  RewriterBase &rewriter,
1345  const DataLayout &dataLayout) {
1346  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
1347 }
1348 
1349 bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
1350  return memcpyLoadsFrom(*this, slot);
1351 }
1352 
1353 bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
1354  return memcpyStoresTo(*this, slot);
1355 }
1356 
1357 Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
1358  Value reachingDef,
1359  const DataLayout &dataLayout) {
1360  return memcpyGetStored(*this, slot, rewriter);
1361 }
1362 
1363 bool LLVM::MemmoveOp::canUsesBeRemoved(
1364  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1365  SmallVectorImpl<OpOperand *> &newBlockingUses,
1366  const DataLayout &dataLayout) {
1367  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1368  dataLayout);
1369 }
1370 
1371 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
1372  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1373  RewriterBase &rewriter, Value reachingDefinition,
1374  const DataLayout &dataLayout) {
1375  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
1376  reachingDefinition);
1377 }
1378 
1379 LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
1380  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1381  const DataLayout &dataLayout) {
1382  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1383 }
1384 
1385 bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
1386  SmallPtrSetImpl<Attribute> &usedIndices,
1387  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1388  const DataLayout &dataLayout) {
1389  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1390  dataLayout);
1391 }
1392 
1393 DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
1395  RewriterBase &rewriter,
1396  const DataLayout &dataLayout) {
1397  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
1398 }
1399 
1400 //===----------------------------------------------------------------------===//
1401 // Interfaces for destructurable types
1402 //===----------------------------------------------------------------------===//
1403 
1404 std::optional<DenseMap<Attribute, Type>>
1406  Type i32 = IntegerType::get(getContext(), 32);
1407  DenseMap<Attribute, Type> destructured;
1408  for (const auto &[index, elemType] : llvm::enumerate(getBody()))
1409  destructured.insert({IntegerAttr::get(i32, index), elemType});
1410  return destructured;
1411 }
1412 
1414  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1415  if (!indexAttr || !indexAttr.getType().isInteger(32))
1416  return {};
1417  int32_t indexInt = indexAttr.getInt();
1418  ArrayRef<Type> body = getBody();
1419  if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
1420  return {};
1421  return body[indexInt];
1422 }
1423 
1424 std::optional<DenseMap<Attribute, Type>>
1425 LLVM::LLVMArrayType::getSubelementIndexMap() const {
1426  constexpr size_t maxArraySizeForDestructuring = 16;
1427  if (getNumElements() > maxArraySizeForDestructuring)
1428  return {};
1429  int32_t numElements = getNumElements();
1430 
1431  Type i32 = IntegerType::get(getContext(), 32);
1432  DenseMap<Attribute, Type> destructured;
1433  for (int32_t index = 0; index < numElements; ++index)
1434  destructured.insert({IntegerAttr::get(i32, index), getElementType()});
1435  return destructured;
1436 }
1437 
1439  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1440  if (!indexAttr || !indexAttr.getType().isInteger(32))
1441  return {};
1442  int32_t indexInt = indexAttr.getInt();
1443  if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
1444  return {};
1445  return getElementType();
1446 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context, unsigned size)
Constructs a byte array type of the given size.
static LogicalResult memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed)
static Value createInsertAndCast(RewriterBase &rewriter, Location loc, Value srcValue, Value reachingDef, const DataLayout &dataLayout)
Constructs operations that insert the bits of srcValue into the "beginning" of reachingDef (beginning...
static DeletionKind memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, RewriterBase &rewriter, Value reachingDefinition)
static std::optional< uint64_t > gepToByteOffset(const DataLayout &dataLayout, LLVM::GEPOp gep)
Returns the amount of bytes the provided GEP elements will offset the pointer by.
static bool areAllIndicesI32(const DestructurableMemorySlot &slot)
Checks whether all indices are i32.
static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc, Value val, Type targetType)
Converts a value with an integer type to targetType.
static std::optional< SubslotAccessInfo > getSubslotAccessInfo(const DestructurableMemorySlot &slot, const DataLayout &dataLayout, LLVM::GEPOp gep)
Computes subslot access information for an access into slot with the given offset.
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool areConversionCompatible(const DataLayout &layout, Type targetType, Type srcType, bool narrowingConversion)
Checks that srcType can be converted to targetType by a sequence of casts and truncations.
static bool forwardToUsers(Operation *op, SmallVectorImpl< OpOperand * > &newBlockingUses)
Conditions the deletion of the operation to the removal of all its uses.
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot)
static Value createExtractAndCast(RewriterBase &rewriter, Location loc, Value srcValue, Type targetType, const DataLayout &dataLayout)
Constructs operations that convert srcValue into a new value of type targetType.
static bool isSupportedTypeForConversion(Type type)
Checks if type can be used in any kind of conversion sequences.
static bool memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, SmallVectorImpl< OpOperand * > &newBlockingUses, const DataLayout &dataLayout)
static DeletionKind memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot, DenseMap< Attribute, MemorySlot > &subslots, RewriterBase &rewriter, const DataLayout &dataLayout)
Rewires a memcpy-like operation.
static bool isBigEndian(const DataLayout &dataLayout)
Checks if dataLayout describes a big endian layout.
static Value castSameSizedTypes(RewriterBase &rewriter, Location loc, Value srcValue, Type targetType, const DataLayout &dataLayout)
Constructs operations that convert srcValue into a new value of type targetType.
static bool hasAllZeroIndices(LLVM::GEPOp gepOp)
static bool isValidAccessType(const MemorySlot &slot, Type accessType, const DataLayout &dataLayout)
Checks if slot can be accessed through the provided access type.
static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val, const DataLayout &dataLayout)
Converts a value to an integer type of the same size.
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot, RewriterBase &rewriter)
static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot, const DataLayout &dataLayout)
Returns whether one can be sure the memory intrinsic does not write outside of the bounds of the give...
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot, SmallPtrSetImpl< Attribute > &usedIndices, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed, const DataLayout &dataLayout)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1541
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Attribute getEndianness() const
Returns the specified endianness.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:109
Type getTypeAtIndex(Attribute index)
Returns which type is stored at a given integer index within the struct.
bool isPacked() const
Checks if a struct is packed.
Definition: LLVMTypes.cpp:482
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Definition: LLVMTypes.cpp:490
std::optional< DenseMap< Attribute, Type > > getSubelementIndexMap()
Destructs the struct into its indexed field types.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr int kGEPConstantBitWidth
Bit-width of a 'GEPConstantIndex' within GEPArg.
Definition: LLVMDialect.h:65
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the corresponding subelement type.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.