1 //===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MemorySlot-related interfaces for LLVM dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Interfaces/DataLayoutInterfaces.h"
19 #include "mlir/Interfaces/MemorySlotInterfaces.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define DEBUG_TYPE "sroa"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Interfaces for AllocaOp
29 //===----------------------------------------------------------------------===//
30 
31 llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
32  if (!getOperation()->getBlock()->isEntryBlock())
33  return {};
34 
35  return {MemorySlot{getResult(), getElemType()}};
36 }
37 
38 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
39  OpBuilder &builder) {
40  return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
41 }
42 
43 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
44  BlockArgument argument,
45  OpBuilder &builder) {
46  for (Operation *user : getOperation()->getUsers())
47  if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
48  builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
49  declareOp.getVarInfo(),
50  declareOp.getLocationExpr());
51 }
52 
53 std::optional<PromotableAllocationOpInterface>
54 LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
55  Value defaultValue,
56  OpBuilder &builder) {
57  if (defaultValue && defaultValue.use_empty())
58  defaultValue.getDefiningOp()->erase();
59  this->erase();
60  return std::nullopt;
61 }
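
// For illustration, promoting a hypothetical slot rewrites a pattern such as
//   %1 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
//   llvm.store %v, %1 : i32, !llvm.ptr
//   %2 = llvm.load %1 : !llvm.ptr -> i32
// by forwarding %v directly to the users of %2; once every use is gone,
// handlePromotionComplete erases the alloca (and the materialized undef
// default value, if it ended up unused).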
62 
63 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
64  if (!mlir::matchPattern(getArraySize(), m_One()))
65  return {};
66 
67  auto destructurable = dyn_cast<DestructurableTypeInterface>(getElemType());
68  if (!destructurable)
69  return {};
70 
71  std::optional<DenseMap<Attribute, Type>> destructuredType =
72  destructurable.getSubelementIndexMap();
73  if (!destructuredType)
74  return {};
75 
76  return {DestructurableMemorySlot{{getResult(), getElemType()},
77  *destructuredType}};
78 }
79 
80 DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
81  const DestructurableMemorySlot &slot,
82  const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
83  SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
84  assert(slot.ptr == getResult());
85  builder.setInsertionPointAfter(*this);
86 
87  auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
88  DenseMap<Attribute, MemorySlot> slotMap;
89  for (Attribute index : usedIndices) {
90  Type elemType = destructurableType.getTypeAtIndex(index);
91  assert(elemType && "used index must exist");
92  auto subAlloca = builder.create<LLVM::AllocaOp>(
93  getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
94  getArraySize());
95  newAllocators.push_back(subAlloca);
96  slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
97  }
98 
99  return slotMap;
100 }
101 
102 std::optional<DestructurableAllocationOpInterface>
103 LLVM::AllocaOp::handleDestructuringComplete(
104  const DestructurableMemorySlot &slot, OpBuilder &builder) {
105  assert(slot.ptr == getResult());
106  this->erase();
107  return std::nullopt;
108 }
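
// For illustration, destructuring a hypothetical slot of type
// !llvm.struct<(i32, f32)> whose two fields are both used rewrites
//   %1 = llvm.alloca %c1 x !llvm.struct<(i32, f32)> : (i32) -> !llvm.ptr
// into one alloca per used field:
//   %f0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
//   %f1 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
// Accesses are then rewired to the sub-allocas and the original alloca is
// erased by handleDestructuringComplete.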
109 
110 //===----------------------------------------------------------------------===//
111 // Interfaces for LoadOp/StoreOp
112 //===----------------------------------------------------------------------===//
113 
114 bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
115  return getAddr() == slot.ptr;
116 }
117 
118 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
119 
120 Value LLVM::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
121  Value reachingDef, const DataLayout &dataLayout) {
122  llvm_unreachable("getStored should not be called on LoadOp");
123 }
124 
125 bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
126 
127 bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
128  return getAddr() == slot.ptr;
129 }
130 
131 /// Checks if `type` can be used in any kind of conversion sequences.
132 static bool isSupportedTypeForConversion(Type type) {
133  // Aggregate types are not bitcastable.
134  if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
135  return false;
136 
137  // LLVM vector types are only used for either pointers or target specific
138  // types. These types cannot be cast in the general case, thus the memory
139  // optimizations do not support them.
140  if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
141  return false;
142 
143  // Scalable types are not supported.
144  if (auto vectorType = dyn_cast<VectorType>(type))
145  return !vectorType.isScalable();
146  return true;
147 }
148 
149 /// Checks that `srcType` can be converted to `targetType` by a sequence of
150 /// casts and truncations. Checks for narrowing or widening conversion
151 /// compatibility depending on `narrowingConversion`.
152 static bool areConversionCompatible(const DataLayout &layout, Type targetType,
153  Type srcType, bool narrowingConversion) {
154  if (targetType == srcType)
155  return true;
156 
157  if (!isSupportedTypeForConversion(targetType) ||
158  !isSupportedTypeForConversion(srcType))
159  return false;
160 
161  uint64_t targetSize = layout.getTypeSize(targetType);
162  uint64_t srcSize = layout.getTypeSize(srcType);
163 
164  // Pointer casts will only be sane when the bitsize of both pointer types is
165  // the same.
166  if (isa<LLVM::LLVMPointerType>(targetType) &&
167  isa<LLVM::LLVMPointerType>(srcType))
168  return targetSize == srcSize;
169 
170  if (narrowingConversion)
171  return targetSize <= srcSize;
172  return targetSize >= srcSize;
173 }
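
// For illustration (assuming a layout with 64-bit pointers): i32 -> i64 is a
// valid widening conversion, i64 -> i32 is a valid narrowing conversion,
// pointers only convert to pointers of the same size, and aggregate or
// scalable vector types are rejected by isSupportedTypeForConversion.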
174 
175 /// Checks if `dataLayout` describes a big endian layout.
176 static bool isBigEndian(const DataLayout &dataLayout) {
177  auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
178  return endiannessStr && endiannessStr == "big";
179 }
180 
181 /// Converts a value to an integer type of the same size.
182 /// Assumes that the type can be converted.
183 static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val,
184  const DataLayout &dataLayout) {
185  Type type = val.getType();
186  assert(isSupportedTypeForConversion(type) &&
187  "expected value to have a convertible type");
188 
189  if (isa<IntegerType>(type))
190  return val;
191 
192  uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
193  IntegerType valueSizeInteger = builder.getIntegerType(typeBitSize);
194 
195  if (isa<LLVM::LLVMPointerType>(type))
196  return builder.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
197  return builder.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
198 }
199 
200 /// Converts a value with an integer type to `targetType`.
201 static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc,
202  Value val, Type targetType) {
203  assert(isa<IntegerType>(val.getType()) &&
204  "expected value to have an integer type");
205  assert(isSupportedTypeForConversion(targetType) &&
206  "expected the target type to be supported for conversions");
207  if (val.getType() == targetType)
208  return val;
209  if (isa<LLVM::LLVMPointerType>(targetType))
210  return builder.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
211  return builder.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
212 }
213 
214 /// Constructs operations that convert `srcValue` into a new value of type
215 /// `targetType`. Assumes the types have the same bitsize.
216 static Value castSameSizedTypes(OpBuilder &builder, Location loc,
217  Value srcValue, Type targetType,
218  const DataLayout &dataLayout) {
219  Type srcType = srcValue.getType();
220  assert(areConversionCompatible(dataLayout, targetType, srcType,
221  /*narrowingConversion=*/true) &&
222  "expected that the compatibility was checked before");
223 
224  // Nothing has to be done if the types are already the same.
225  if (srcType == targetType)
226  return srcValue;
227 
228  // In the special case of casting one pointer to another, we want to generate
229  // an address space cast. Bitcasts of pointers are not allowed and using
230  // pointer to integer conversions are not equivalent due to the loss of
231  // provenance.
232  if (isa<LLVM::LLVMPointerType>(targetType) &&
233  isa<LLVM::LLVMPointerType>(srcType))
234  return builder.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
235  srcValue);
236 
237  // For all other castable types, casting through integers is necessary.
238  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
239  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
240 }
241 
242 /// Constructs operations that convert `srcValue` into a new value of type
243 /// `targetType`. Performs bit-level extraction if the source type is larger
244 /// than the target type. Assumes that this conversion is possible.
245 static Value createExtractAndCast(OpBuilder &builder, Location loc,
246  Value srcValue, Type targetType,
247  const DataLayout &dataLayout) {
248  // Get the types of the source and target values.
249  Type srcType = srcValue.getType();
250  assert(areConversionCompatible(dataLayout, targetType, srcType,
251  /*narrowingConversion=*/true) &&
252  "expected that the compatibility was checked before");
253 
254  uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
255  uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
256  if (srcTypeSize == targetTypeSize)
257  return castSameSizedTypes(builder, loc, srcValue, targetType, dataLayout);
258 
259  // First, cast the value to a same-sized integer type.
260  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
261 
262  // Truncate the integer if the size of the target is less than the value.
263  if (isBigEndian(dataLayout)) {
264  uint64_t shiftAmount = srcTypeSize - targetTypeSize;
265  auto shiftConstant = builder.create<LLVM::ConstantOp>(
266  loc, builder.getIntegerAttr(srcType, shiftAmount));
267  replacement =
268  builder.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
269  }
270 
271  replacement = builder.create<LLVM::TruncOp>(
272  loc, builder.getIntegerType(targetTypeSize), replacement);
273 
274  // Now cast the integer to the actual target type if required.
275  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
276 }
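
// Worked example: extracting an i16 from an i32 reaching definition %def (a
// hypothetical 16-bit load from a 32-bit slot). On little endian layouts this
// is a single truncation:
//   %r = llvm.trunc %def : i32 to i16
// On big endian layouts the interesting bits are the most significant ones,
// so they are shifted down first:
//   %c16 = llvm.mlir.constant(16 : i32) : i32
//   %s = llvm.lshr %def, %c16 : i32
//   %r = llvm.trunc %s : i32 to i16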
277 
278 /// Constructs operations that insert the bits of `srcValue` into the
279 /// "beginning" of `reachingDef` (beginning is endianness dependent).
280 /// Assumes that this conversion is possible.
281 static Value createInsertAndCast(OpBuilder &builder, Location loc,
282  Value srcValue, Value reachingDef,
283  const DataLayout &dataLayout) {
284 
285  assert(areConversionCompatible(dataLayout, reachingDef.getType(),
286  srcValue.getType(),
287  /*narrowingConversion=*/false) &&
288  "expected that the compatibility was checked before");
289  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
290  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
291  if (slotTypeSize == valueTypeSize)
292  return castSameSizedTypes(builder, loc, srcValue, reachingDef.getType(),
293  dataLayout);
294 
295  // In the case where the store only overwrites parts of the memory,
296  // bit fiddling is required to construct the new value.
297 
298  // First convert both values to integers of the same size.
299  Value defAsInt = castToSameSizedInt(builder, loc, reachingDef, dataLayout);
300  Value valueAsInt = castToSameSizedInt(builder, loc, srcValue, dataLayout);
301  // Extend the value to the size of the reaching definition.
302  valueAsInt =
303  builder.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
304  uint64_t sizeDifference = slotTypeSize - valueTypeSize;
305  if (isBigEndian(dataLayout)) {
306  // On big endian systems, a store to the base pointer overwrites the most
307  // significant bits. To accommodate this, the stored value needs to be
308  // shifted into the corresponding position.
309  Value bigEndianShift = builder.create<LLVM::ConstantOp>(
310  loc, builder.getIntegerAttr(defAsInt.getType(), sizeDifference));
311  valueAsInt =
312  builder.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
313  }
314 
315  // Construct the mask that is used to erase the bits that are overwritten by
316  // the store.
317  APInt maskValue;
318  if (isBigEndian(dataLayout)) {
319  // Build a mask that has the most significant bits set to zero.
320  // Note: This is the same as 2^sizeDifference - 1
321  maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
322  } else {
323  // Build a mask that has the least significant bits set to zero.
324  // Note: This is the same as -(2^valueTypeSize)
325  maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
326  maskValue.flipAllBits();
327  }
328 
329  // Mask out the affected bits ...
330  Value mask = builder.create<LLVM::ConstantOp>(
331  loc, builder.getIntegerAttr(defAsInt.getType(), maskValue));
332  Value masked = builder.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
333 
334  // ... and combine the result with the new value.
335  Value combined = builder.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
336 
337  return castIntValueToSameSizedType(builder, loc, combined,
338  reachingDef.getType());
339 }
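
// Worked example: inserting an i16 value %v into an i32 reaching definition
// %def (hypothetical types). Both are brought to i32 first; on little endian
// layouts the low 16 bits of %def are cleared and replaced:
//   %ext = llvm.zext %v : i16 to i32
//   %mask = llvm.mlir.constant(-65536 : i32) : i32   // 0xFFFF0000
//   %masked = llvm.and %def, %mask : i32
//   %new = llvm.or %masked, %ext : i32
// On big endian layouts the extended value is additionally shifted left by 16
// and the mask keeps the low 16 bits (0x0000FFFF) instead.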
340 
341 Value LLVM::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
342  Value reachingDef,
343  const DataLayout &dataLayout) {
344  assert(reachingDef && reachingDef.getType() == slot.elemType &&
345  "expected the reaching definition's type to match the slot's type");
346  return createInsertAndCast(builder, getLoc(), getValue(), reachingDef,
347  dataLayout);
348 }
349 
350 bool LLVM::LoadOp::canUsesBeRemoved(
351  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
352  SmallVectorImpl<OpOperand *> &newBlockingUses,
353  const DataLayout &dataLayout) {
354  if (blockingUses.size() != 1)
355  return false;
356  Value blockingUse = (*blockingUses.begin())->get();
357  // If the blocking use is the slot ptr itself, there will be enough
358  // context to reconstruct the result of the load at removal time, so it can
359  // be removed (provided it is not volatile).
360  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
361  areConversionCompatible(dataLayout, getResult().getType(),
362  slot.elemType, /*narrowingConversion=*/true) &&
363  !getVolatile_();
364 }
365 
366 DeletionKind LLVM::LoadOp::removeBlockingUses(
367  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
368  OpBuilder &builder, Value reachingDefinition,
369  const DataLayout &dataLayout) {
370  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
371  // pointer.
372  Value newResult = createExtractAndCast(builder, getLoc(), reachingDefinition,
373  getResult().getType(), dataLayout);
374  getResult().replaceAllUsesWith(newResult);
375  return DeletionKind::Delete;
376 }
377 
378 bool LLVM::StoreOp::canUsesBeRemoved(
379  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
380  SmallVectorImpl<OpOperand *> &newBlockingUses,
381  const DataLayout &dataLayout) {
382  if (blockingUses.size() != 1)
383  return false;
384  Value blockingUse = (*blockingUses.begin())->get();
385  // If the blocking use is the slot ptr itself, dropping the store is
386  // fine, provided we are currently promoting its target value. Don't allow a
387  // store OF the slot pointer, only INTO the slot pointer.
388  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
389  getValue() != slot.ptr &&
390  areConversionCompatible(dataLayout, slot.elemType,
391  getValue().getType(),
392  /*narrowingConversion=*/false) &&
393  !getVolatile_();
394 }
395 
396 DeletionKind LLVM::StoreOp::removeBlockingUses(
397  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
398  OpBuilder &builder, Value reachingDefinition,
399  const DataLayout &dataLayout) {
400  return DeletionKind::Delete;
401 }
402 
403 /// Checks if `slot` can be accessed through the provided access type.
404 static bool isValidAccessType(const MemorySlot &slot, Type accessType,
405  const DataLayout &dataLayout) {
406  return dataLayout.getTypeSize(accessType) <=
407  dataLayout.getTypeSize(slot.elemType);
408 }
409 
410 LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
411  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
412  const DataLayout &dataLayout) {
413  return success(getAddr() != slot.ptr ||
414  isValidAccessType(slot, getType(), dataLayout));
415 }
416 
417 LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
418  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
419  const DataLayout &dataLayout) {
420  return success(getAddr() != slot.ptr ||
421  isValidAccessType(slot, getValue().getType(), dataLayout));
422 }
423 
424 /// Returns the subslot's type at the requested index.
425 static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
426  Attribute index) {
427  auto subelementIndexMap =
428  cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
429  if (!subelementIndexMap)
430  return {};
431  assert(!subelementIndexMap->empty());
432 
433  // Note: Returns a null-type when no entry was found.
434  return subelementIndexMap->lookup(index);
435 }
436 
437 bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
438  SmallPtrSetImpl<Attribute> &usedIndices,
439  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
440  const DataLayout &dataLayout) {
441  if (getVolatile_())
442  return false;
443 
444  // A load always accesses the first element of the destructured slot.
445  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
446  Type subslotType = getTypeAtIndex(slot, index);
447  if (!subslotType)
448  return false;
449 
450  // The access can only be replaced when the subslot is read within its bounds.
451  if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
452  return false;
453 
454  usedIndices.insert(index);
455  return true;
456 }
457 
458 DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
459  DenseMap<Attribute, MemorySlot> &subslots,
460  OpBuilder &builder,
461  const DataLayout &dataLayout) {
462  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
463  auto it = subslots.find(index);
464  assert(it != subslots.end());
465 
466  getAddrMutable().set(it->getSecond().ptr);
467  return DeletionKind::Keep;
468 }
469 
470 bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
471  SmallPtrSetImpl<Attribute> &usedIndices,
472  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
473  const DataLayout &dataLayout) {
474  if (getVolatile_())
475  return false;
476 
477  // Storing the pointer to memory cannot be dealt with.
478  if (getValue() == slot.ptr)
479  return false;
480 
481  // A store always accesses the first element of the destructured slot.
482  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
483  Type subslotType = getTypeAtIndex(slot, index);
484  if (!subslotType)
485  return false;
486 
487  // The access can only be replaced when the subslot is written within its bounds.
488  if (dataLayout.getTypeSize(getValue().getType()) >
489  dataLayout.getTypeSize(subslotType))
490  return false;
491 
492  usedIndices.insert(index);
493  return true;
494 }
495 
496 DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
497  DenseMap<Attribute, MemorySlot> &subslots,
498  OpBuilder &builder,
499  const DataLayout &dataLayout) {
500  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
501  auto it = subslots.find(index);
502  assert(it != subslots.end());
503 
504  getAddrMutable().set(it->getSecond().ptr);
505  return DeletionKind::Keep;
506 }
507 
508 //===----------------------------------------------------------------------===//
509 // Interfaces for discardable OPs
510 //===----------------------------------------------------------------------===//
511 
512 /// Conditions the deletion of the operation to the removal of all its uses.
513 static bool forwardToUsers(Operation *op,
514  SmallVectorImpl<OpOperand *> &newBlockingUses) {
515  for (Value result : op->getResults())
516  for (OpOperand &use : result.getUses())
517  newBlockingUses.push_back(&use);
518  return true;
519 }
520 
521 bool LLVM::BitcastOp::canUsesBeRemoved(
522  const SmallPtrSetImpl<OpOperand *> &blockingUses,
523  SmallVectorImpl<OpOperand *> &newBlockingUses,
524  const DataLayout &dataLayout) {
525  return forwardToUsers(*this, newBlockingUses);
526 }
527 
528 DeletionKind LLVM::BitcastOp::removeBlockingUses(
529  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
530  return DeletionKind::Delete;
531 }
532 
533 bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
534  const SmallPtrSetImpl<OpOperand *> &blockingUses,
535  SmallVectorImpl<OpOperand *> &newBlockingUses,
536  const DataLayout &dataLayout) {
537  return forwardToUsers(*this, newBlockingUses);
538 }
539 
540 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
541  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
542  return DeletionKind::Delete;
543 }
544 
545 bool LLVM::LifetimeStartOp::canUsesBeRemoved(
546  const SmallPtrSetImpl<OpOperand *> &blockingUses,
547  SmallVectorImpl<OpOperand *> &newBlockingUses,
548  const DataLayout &dataLayout) {
549  return true;
550 }
551 
552 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
553  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
554  return DeletionKind::Delete;
555 }
556 
557 bool LLVM::LifetimeEndOp::canUsesBeRemoved(
558  const SmallPtrSetImpl<OpOperand *> &blockingUses,
559  SmallVectorImpl<OpOperand *> &newBlockingUses,
560  const DataLayout &dataLayout) {
561  return true;
562 }
563 
564 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
565  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
566  return DeletionKind::Delete;
567 }
568 
569 bool LLVM::InvariantStartOp::canUsesBeRemoved(
570  const SmallPtrSetImpl<OpOperand *> &blockingUses,
571  SmallVectorImpl<OpOperand *> &newBlockingUses,
572  const DataLayout &dataLayout) {
573  return true;
574 }
575 
576 DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
577  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
578  return DeletionKind::Delete;
579 }
580 
581 bool LLVM::InvariantEndOp::canUsesBeRemoved(
582  const SmallPtrSetImpl<OpOperand *> &blockingUses,
583  SmallVectorImpl<OpOperand *> &newBlockingUses,
584  const DataLayout &dataLayout) {
585  return true;
586 }
587 
588 DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
589  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
590  return DeletionKind::Delete;
591 }
592 
593 bool LLVM::DbgDeclareOp::canUsesBeRemoved(
594  const SmallPtrSetImpl<OpOperand *> &blockingUses,
595  SmallVectorImpl<OpOperand *> &newBlockingUses,
596  const DataLayout &dataLayout) {
597  return true;
598 }
599 
600 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
601  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
602  return DeletionKind::Delete;
603 }
604 
605 bool LLVM::DbgValueOp::canUsesBeRemoved(
606  const SmallPtrSetImpl<OpOperand *> &blockingUses,
607  SmallVectorImpl<OpOperand *> &newBlockingUses,
608  const DataLayout &dataLayout) {
609  // There is only one operand that we can remove the use of.
610  if (blockingUses.size() != 1)
611  return false;
612 
613  return (*blockingUses.begin())->get() == getValue();
614 }
615 
616 DeletionKind LLVM::DbgValueOp::removeBlockingUses(
617  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
618  // builder by default is after '*this', but we need it before '*this'.
619  builder.setInsertionPoint(*this);
620 
621  // Rather than dropping the debug value, replace it with undef to preserve the
622  // debug local variable info. This allows the debugger to inform the user that
623  // the variable has been optimized out.
624  auto undef =
625  builder.create<UndefOp>(getValue().getLoc(), getValue().getType());
626  getValueMutable().assign(undef);
627  return DeletionKind::Keep;
628 }
629 
630 bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
631 
632 void LLVM::DbgDeclareOp::visitReplacedValues(
633  ArrayRef<std::pair<Operation *, Value>> definitions, OpBuilder &builder) {
634  for (auto [op, value] : definitions) {
635  builder.setInsertionPointAfter(op);
636  builder.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
637  getLocationExpr());
638  }
639 }
640 
641 //===----------------------------------------------------------------------===//
642 // Interfaces for GEPOp
643 //===----------------------------------------------------------------------===//
644 
645 static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
646  return llvm::all_of(gepOp.getIndices(), [](auto index) {
647  auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
648  return indexAttr && indexAttr.getValue() == 0;
649  });
650 }
651 
652 bool LLVM::GEPOp::canUsesBeRemoved(
653  const SmallPtrSetImpl<OpOperand *> &blockingUses,
654  SmallVectorImpl<OpOperand *> &newBlockingUses,
655  const DataLayout &dataLayout) {
656  // GEP can be removed as long as it is a no-op and its users can be removed.
657  if (!hasAllZeroIndices(*this))
658  return false;
659  return forwardToUsers(*this, newBlockingUses);
660 }
661 
662 DeletionKind LLVM::GEPOp::removeBlockingUses(
663  const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
664  return DeletionKind::Delete;
665 }
666 
667 /// Returns the number of bytes the provided GEP indices will offset the
668 /// pointer by. Returns nullopt if no constant offset could be computed.
669 static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
670  LLVM::GEPOp gep) {
671  // Collects all indices.
672  SmallVector<uint64_t> indices;
673  for (auto index : gep.getIndices()) {
674  auto constIndex = dyn_cast<IntegerAttr>(index);
675  if (!constIndex)
676  return {};
677  int64_t gepIndex = constIndex.getInt();
678  // Negative indices are not supported.
679  if (gepIndex < 0)
680  return {};
681  indices.push_back(gepIndex);
682  }
683 
684  Type currentType = gep.getElemType();
685  uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);
686 
687  for (uint64_t index : llvm::drop_begin(indices)) {
688  bool shouldCancel =
689  TypeSwitch<Type, bool>(currentType)
690  .Case([&](LLVM::LLVMArrayType arrayType) {
691  offset +=
692  index * dataLayout.getTypeSize(arrayType.getElementType());
693  currentType = arrayType.getElementType();
694  return false;
695  })
696  .Case([&](LLVM::LLVMStructType structType) {
697  ArrayRef<Type> body = structType.getBody();
698  assert(index < body.size() && "expected valid struct indexing");
699  for (uint32_t i : llvm::seq(index)) {
700  if (!structType.isPacked())
701  offset = llvm::alignTo(
702  offset, dataLayout.getTypeABIAlignment(body[i]));
703  offset += dataLayout.getTypeSize(body[i]);
704  }
705 
706  // Align for the current type as well.
707  if (!structType.isPacked())
708  offset = llvm::alignTo(
709  offset, dataLayout.getTypeABIAlignment(body[index]));
710  currentType = body[index];
711  return false;
712  })
713  .Default([&](Type type) {
714  LLVM_DEBUG(llvm::dbgs()
715  << "[sroa] Unsupported type for offset computations"
716  << type << "\n");
717  return true;
718  });
719 
720  if (shouldCancel)
721  return std::nullopt;
722  }
723 
724  return offset;
725 }
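
// Worked example (hypothetical GEP, default ABI alignments): for
//   llvm.getelementptr %p[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i64)>
// the first index contributes 0 * sizeof(struct) bytes, field 0 contributes
// sizeof(i32) = 4 bytes, and aligning to the i64 alignment of 8 yields a
// final byte offset of 8.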
726 
727 namespace {
728 /// A struct that stores both the index into the aggregate type of the slot as
729 /// well as the corresponding byte offset in memory.
730 struct SubslotAccessInfo {
731  /// The parent slot's index that the access falls into.
732  uint32_t index;
733  /// The offset into the subslot of the access.
734  uint64_t subslotOffset;
735 };
736 } // namespace
737 
738 /// Computes subslot access information for an access into `slot` with the given
739 /// offset.
740 /// Returns nullopt when the offset is out-of-bounds or when the access is into
741 /// the padding of `slot`.
742 static std::optional<SubslotAccessInfo>
743 getSubslotAccessInfo(const DestructurableMemorySlot &slot,
744  const DataLayout &dataLayout, LLVM::GEPOp gep) {
745  std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
746  if (!offset)
747  return {};
748 
749  // Helper to check that a constant index is in the bounds of the GEP index
750  // representation. The LLVM dialect's GEP arguments have a limited bitwidth,
751  // thus this additional check is necessary.
752  auto isOutOfBoundsGEPIndex = [](uint64_t index) {
753  return index >= (1 << LLVM::kGEPConstantBitWidth);
754  };
755 
756  Type type = slot.elemType;
757  if (*offset >= dataLayout.getTypeSize(type))
758  return {};
759  return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
760  .Case([&](LLVM::LLVMArrayType arrayType)
761  -> std::optional<SubslotAccessInfo> {
762  // Find which element of the array contains the offset.
763  uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
764  uint64_t index = *offset / elemSize;
765  if (isOutOfBoundsGEPIndex(index))
766  return {};
767  return SubslotAccessInfo{static_cast<uint32_t>(index),
768  *offset - (index * elemSize)};
769  })
770  .Case([&](LLVM::LLVMStructType structType)
771  -> std::optional<SubslotAccessInfo> {
772  uint64_t distanceToStart = 0;
773  // Walk over the elements of the struct to find in which of
774  // them the offset is.
775  for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
776  uint64_t elemSize = dataLayout.getTypeSize(elem);
777  if (!structType.isPacked()) {
778  distanceToStart = llvm::alignTo(
779  distanceToStart, dataLayout.getTypeABIAlignment(elem));
780  // If the offset is in padding, cancel the rewrite.
781  if (offset < distanceToStart)
782  return {};
783  }
784 
785  if (offset < distanceToStart + elemSize) {
786  if (isOutOfBoundsGEPIndex(index))
787  return {};
788  // The offset is within this element, stop iterating the
789  // struct and return the index.
790  return SubslotAccessInfo{static_cast<uint32_t>(index),
791  *offset - distanceToStart};
792  }
793 
794  // The offset is not within this element, continue walking
795  // over the struct.
796  distanceToStart += elemSize;
797  }
798 
799  return {};
800  });
801 }
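
// Worked example: for a hypothetical slot of type !llvm.struct<(i32, i64)>
// and a GEP whose byte offset is 8, the walk skips field 0 (bytes 0..3) and
// the padding up to byte 8, and returns SubslotAccessInfo{/*index=*/1,
// /*subslotOffset=*/0}. A byte offset of 5 would land in the padding between
// the fields and cancel the rewrite.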
802 
803 /// Constructs a byte array type of the given size.
804 static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
805  unsigned size) {
806  auto byteType = IntegerType::get(context, 8);
807  return LLVM::LLVMArrayType::get(context, byteType, size);
808 }
809 
810 LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
811  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
812  const DataLayout &dataLayout) {
813  if (getBase() != slot.ptr)
814  return success();
815  std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
816  if (!gepOffset)
817  return failure();
818  uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
819  // Check that the access is strictly inside the slot.
820  if (*gepOffset >= slotSize)
821  return failure();
822  // Every access that remains in bounds of the remaining slot is considered
823  // legal.
824  mustBeSafelyUsed.emplace_back<MemorySlot>(
825  {getRes(), getByteArrayType(getContext(), slotSize - *gepOffset)});
826  return success();
827 }
828 
829 bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
830  SmallPtrSetImpl<Attribute> &usedIndices,
831  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
832  const DataLayout &dataLayout) {
833  if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
834  return false;
835 
836  if (getBase() != slot.ptr)
837  return false;
838  std::optional<SubslotAccessInfo> accessInfo =
839  getSubslotAccessInfo(slot, dataLayout, *this);
840  if (!accessInfo)
841  return false;
842  auto indexAttr =
843  IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
844  assert(slot.subelementTypes.contains(indexAttr));
845  usedIndices.insert(indexAttr);
846 
847  // The remainder of the subslot should be accessed in-bounds. Thus, we create
848  // a dummy slot with the size of the remainder.
849  Type subslotType = slot.subelementTypes.lookup(indexAttr);
850  uint64_t slotSize = dataLayout.getTypeSize(subslotType);
851  LLVM::LLVMArrayType remainingSlotType =
852  getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
853  mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});
854 
855  return true;
856 }
857 
858 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
859  DenseMap<Attribute, MemorySlot> &subslots,
860  OpBuilder &builder,
861  const DataLayout &dataLayout) {
862  std::optional<SubslotAccessInfo> accessInfo =
863  getSubslotAccessInfo(slot, dataLayout, *this);
864  assert(accessInfo && "expected access info to be checked before");
865  auto indexAttr =
866  IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
867  const MemorySlot &newSlot = subslots.at(indexAttr);
868 
869  auto byteType = IntegerType::get(builder.getContext(), 8);
870  auto newPtr = builder.createOrFold<LLVM::GEPOp>(
871  getLoc(), getResult().getType(), byteType, newSlot.ptr,
872  ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
873  getResult().replaceAllUsesWith(newPtr);
874  return DeletionKind::Delete;
875 }
876 
877 //===----------------------------------------------------------------------===//
878 // Utilities for memory intrinsics
879 //===----------------------------------------------------------------------===//
880 
881 namespace {
882 
883 /// Returns the length of the given memory intrinsic in bytes if it can be known
884 /// at compile-time on a best-effort basis, nothing otherwise.
885 template <class MemIntr>
886 std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
887  APInt memIntrLen;
888  if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
889  return {};
890  if (memIntrLen.getBitWidth() > 64)
891  return {};
892  return memIntrLen.getZExtValue();
893 }
894 
895 /// Returns the length of the given memory intrinsic in bytes if it can be known
896 /// at compile-time on a best-effort basis, nothing otherwise.
897 /// Because MemcpyInlineOp has its length encoded as an attribute, this requires
898 /// specialized handling.
899 template <>
900 std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
901  APInt memIntrLen = op.getLen();
902  if (memIntrLen.getBitWidth() > 64)
903  return {};
904  return memIntrLen.getZExtValue();
905 }
906 
907 } // namespace
908 
909 /// Returns whether one can be sure the memory intrinsic does not write outside
910 /// of the bounds of the given slot, on a best-effort basis.
911 template <class MemIntr>
912 static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
913  const DataLayout &dataLayout) {
914  if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
915  op.getDst() != slot.ptr)
916  return false;
917 
918  std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op);
919  return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(slot.elemType);
920 }
921 
922 /// Checks whether all indices are i32. This is used to check that GEPs can
923 /// index into them.
924 static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
925  Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
926  return llvm::all_of(llvm::make_first_range(slot.subelementTypes),
927  [&](Attribute index) {
928  auto intIndex = dyn_cast<IntegerAttr>(index);
929  return intIndex && intIndex.getType() == i32;
930  });
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // Interfaces for memset
935 //===----------------------------------------------------------------------===//
936 
937 bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
938 
939 bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
940  return getDst() == slot.ptr;
941 }
942 
943 Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
944  Value reachingDef,
945  const DataLayout &dataLayout) {
946  // TODO: Support non-integer types.
947  return TypeSwitch<Type, Value>(slot.elemType)
948  .Case([&](IntegerType intType) -> Value {
949  if (intType.getWidth() == 8)
950  return getVal();
951 
952  assert(intType.getWidth() % 8 == 0);
953 
954  // Build the memset integer by repeatedly shifting the value and
955  // or-ing it with the previous value.
956  uint64_t coveredBits = 8;
957  Value currentValue =
958  builder.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
959  while (coveredBits < intType.getWidth()) {
960  Value shiftBy =
961  builder.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
962  Value shifted =
963  builder.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
964  currentValue =
965  builder.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
966  coveredBits *= 2;
967  }
968 
969  return currentValue;
970  })
971  .Default([](Type) -> Value {
972  llvm_unreachable(
973  "getStored should not be called on memset to unsupported type");
974  });
975 }
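
// Worked example: for a memset of byte value %b into a hypothetical i32 slot,
// the loop above builds the usual byte-splat:
//   %e = llvm.zext %b : i8 to i32          // bits 0..7 covered
//   %s8 = llvm.shl %e, %c8 : i32
//   %v16 = llvm.or %e, %s8 : i32           // bits 0..15 covered
//   %s16 = llvm.shl %v16, %c16 : i32
//   %v32 = llvm.or %v16, %s16 : i32        // bits 0..31 covered
// so a memset value of 0xAB stores 0xABABABAB into the slot.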
976 
977 bool LLVM::MemsetOp::canUsesBeRemoved(
978  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
979  SmallVectorImpl<OpOperand *> &newBlockingUses,
980  const DataLayout &dataLayout) {
981  // TODO: Support non-integer types.
982  bool canConvertType =
983  TypeSwitch<Type, bool>(slot.elemType)
984  .Case([](IntegerType intType) {
985  return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
986  })
987  .Default([](Type) { return false; });
988  if (!canConvertType)
989  return false;
990 
991  if (getIsVolatile())
992  return false;
993 
994  return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
995 }
996 
997 DeletionKind LLVM::MemsetOp::removeBlockingUses(
998  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
999  OpBuilder &builder, Value reachingDefinition,
1000  const DataLayout &dataLayout) {
1001  return DeletionKind::Delete;
1002 }
1003 
1004 LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
1005  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1006  const DataLayout &dataLayout) {
1007  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
1008 }
1009 
1010 bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
1011  SmallPtrSetImpl<Attribute> &usedIndices,
1012  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1013  const DataLayout &dataLayout) {
1014  if (&slot.elemType.getDialect() != getOperation()->getDialect())
1015  return false;
1016 
1017  if (getIsVolatile())
1018  return false;
1019 
1020  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1021  return false;
1022 
1023  if (!areAllIndicesI32(slot))
1024  return false;
1025 
1026  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
1027 }
1028 
1029 DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
1030  DenseMap<Attribute, MemorySlot> &subslots,
1031  OpBuilder &builder,
1032  const DataLayout &dataLayout) {
1033  std::optional<DenseMap<Attribute, Type>> types =
1034  cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
1035 
1036  IntegerAttr memsetLenAttr;
1037  bool successfulMatch =
1038  matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
1039  (void)successfulMatch;
1040  assert(successfulMatch);
1041 
1042  bool packed = false;
1043  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
1044  packed = structType.isPacked();
1045 
1046  Type i32 = IntegerType::get(getContext(), 32);
1047  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
1048  uint64_t covered = 0;
1049  for (size_t i = 0; i < types->size(); i++) {
1050  // Create indices on the fly to get elements in the right order.
1051  Attribute index = IntegerAttr::get(i32, i);
1052  Type elemType = types->at(index);
1053  uint64_t typeSize = dataLayout.getTypeSize(elemType);
1054 
1055  if (!packed)
1056  covered =
1057  llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
1058 
1059  if (covered >= memsetLen)
1060  break;
1061 
1062  // If this subslot is used, apply a new memset to it.
1063  // Otherwise, only compute its offset within the original memset.
1064  if (subslots.contains(index)) {
1065  uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
1066 
1067  Value newMemsetSizeValue =
1068  builder
1069  .create<LLVM::ConstantOp>(
1070  getLen().getLoc(),
1071  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
1072  .getResult();
1073 
1074  builder.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr, getVal(),
1075  newMemsetSizeValue, getIsVolatile());
1076  }
1077 
1078  covered += typeSize;
1079  }
1080 
1081  return DeletionKind::Delete;
1082 }
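
// For illustration, a hypothetical 12-byte memset over a destructured
// !llvm.struct<(i32, i64)> slot is split per subslot: the i32 field receives
// a 4-byte memset, the running offset is then aligned to 8, and the i64 field
// receives a memset of min(12 - 8, 8) = 4 bytes. Subslots that are not used
// only advance the running offset.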
1083 
1084 //===----------------------------------------------------------------------===//
1085 // Interfaces for memcpy/memmove
1086 //===----------------------------------------------------------------------===//
1087 
1088 template <class MemcpyLike>
1089 static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
1090  return op.getSrc() == slot.ptr;
1091 }
1092 
1093 template <class MemcpyLike>
1094 static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
1095  return op.getDst() == slot.ptr;
1096 }
1097 
1098 template <class MemcpyLike>
1099 static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
1100  OpBuilder &builder) {
1101  return builder.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
1102 }
1103 
1104 template <class MemcpyLike>
1105 static bool
1106 memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
1107  const SmallPtrSetImpl<OpOperand *> &blockingUses,
1108  SmallVectorImpl<OpOperand *> &newBlockingUses,
1109  const DataLayout &dataLayout) {
1110  // If source and destination are the same, memcpy behavior is undefined and
1111  // memmove is a no-op. Because there is no memory change happening here,
1112  // simplifying such operations is left to canonicalization.
1113  if (op.getDst() == op.getSrc())
1114  return false;
1115 
1116  if (op.getIsVolatile())
1117  return false;
1118 
1119  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
1120 }
1121 
1122 template <class MemcpyLike>
1123 static DeletionKind
1124 memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
1125  const SmallPtrSetImpl<OpOperand *> &blockingUses,
1126  OpBuilder &builder, Value reachingDefinition) {
1127  if (op.loadsFrom(slot))
1128  builder.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition, op.getDst());
1129  return DeletionKind::Delete;
1130 }
1131 
1132 template <class MemcpyLike>
1133 static LogicalResult
1134 memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
1135  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
1136  DataLayout dataLayout = DataLayout::closest(op);
1137  // While rewiring memcpy-like intrinsics only supports full copies, partial
1138  // copies are still safe accesses so it is enough to only check for writes
1139  // within bounds.
1140  return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
1141 }
1142 
1143 template <class MemcpyLike>
1144 static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1145  SmallPtrSetImpl<Attribute> &usedIndices,
1146  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1147  const DataLayout &dataLayout) {
1148  if (op.getIsVolatile())
1149  return false;
1150 
1151  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1152  return false;
1153 
1154  if (!areAllIndicesI32(slot))
1155  return false;
1156 
1157  // Only full copies are supported.
1158  if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
1159  return false;
1160 
1161  if (op.getSrc() == slot.ptr)
1162  for (Attribute index : llvm::make_first_range(slot.subelementTypes))
1163  usedIndices.insert(index);
1164 
1165  return true;
1166 }
1167 
1168 namespace {
1169 
1170 template <class MemcpyLike>
1171 void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
1172  MemcpyLike toReplace, Value dst, Value src,
1173  Type toCpy, bool isVolatile) {
1174  Value memcpySize = builder.create<LLVM::ConstantOp>(
1175  toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
1176  layout.getTypeSize(toCpy)));
1177  builder.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
1178  isVolatile);
1179 }
1180 
1181 template <>
1182 void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
1183  LLVM::MemcpyInlineOp toReplace, Value dst,
1184  Value src, Type toCpy, bool isVolatile) {
1185  Type lenType = IntegerType::get(toReplace->getContext(),
1186  toReplace.getLen().getBitWidth());
1187  builder.create<LLVM::MemcpyInlineOp>(
1188  toReplace.getLoc(), dst, src,
1189  IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
1190 }
1191 
1192 } // namespace
1193 
1194 /// Rewires a memcpy-like operation. Only copies to or from the full slot are
1195 /// supported.
1196 template <class MemcpyLike>
1197 static DeletionKind
1198 memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1199  DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
1200  const DataLayout &dataLayout) {
1201  if (subslots.empty())
1202  return DeletionKind::Delete;
1203 
1204  assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
1205  bool isDst = slot.ptr == op.getDst();
1206 
1207 #ifndef NDEBUG
1208  size_t slotsTreated = 0;
1209 #endif
1210 
1211  // It was previously checked that index types are consistent, so this type can
1212  // be fetched now.
1213  Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
1214  for (size_t i = 0, e = slot.subelementTypes.size(); i != e; i++) {
1215  Attribute index = IntegerAttr::get(indexType, i);
1216  if (!subslots.contains(index))
1217  continue;
1218  const MemorySlot &subslot = subslots.at(index);
1219 
1220 #ifndef NDEBUG
1221  slotsTreated++;
1222 #endif
1223 
1224  // First get a pointer to the equivalent of this subslot from the source
1225  // pointer.
1226  SmallVector<LLVM::GEPArg> gepIndices{
1227  0, static_cast<int32_t>(
1228  cast<IntegerAttr>(index).getValue().getZExtValue())};
1229  Value subslotPtrInOther = builder.create<LLVM::GEPOp>(
1230  op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType,
1231  isDst ? op.getSrc() : op.getDst(), gepIndices);
1232 
1233  // Then create a new memcpy out of this source pointer.
1234  createMemcpyLikeToReplace(builder, dataLayout, op,
1235  isDst ? subslot.ptr : subslotPtrInOther,
1236  isDst ? subslotPtrInOther : subslot.ptr,
1237  subslot.elemType, op.getIsVolatile());
1238  }
1239 
1240  assert(subslots.size() == slotsTreated);
1241 
1242  return DeletionKind::Delete;
1243 }
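
// For illustration, rewiring a hypothetical full copy whose destination is a
// destructured !llvm.struct<(i32, i64)> slot emits, for each used subslot, a
// GEP into the untouched source followed by a field-sized copy, e.g. for
// field 1:
//   %src1 = llvm.getelementptr %src[0, 1]
//       : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i64)>
// followed by an llvm.intr.memcpy of that field's byte size between the
// sub-alloca created for field 1 and %src1.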
1244 
1245 bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
1246  return memcpyLoadsFrom(*this, slot);
1247 }
1248 
1249 bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
1250  return memcpyStoresTo(*this, slot);
1251 }
1252 
1253 Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1254  Value reachingDef,
1255  const DataLayout &dataLayout) {
1256  return memcpyGetStored(*this, slot, builder);
1257 }
1258 
1259 bool LLVM::MemcpyOp::canUsesBeRemoved(
1260  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1261  SmallVectorImpl<OpOperand *> &newBlockingUses,
1262  const DataLayout &dataLayout) {
1263  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1264  dataLayout);
1265 }
1266 
1267 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
1268  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1269  OpBuilder &builder, Value reachingDefinition,
1270  const DataLayout &dataLayout) {
1271  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
1272  reachingDefinition);
1273 }
1274 
1275 LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
1276  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1277  const DataLayout &dataLayout) {
1278  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1279 }
1280 
1281 bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
1282  SmallPtrSetImpl<Attribute> &usedIndices,
1283  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1284  const DataLayout &dataLayout) {
1285  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1286  dataLayout);
1287 }
1288 
1289 DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
1290  DenseMap<Attribute, MemorySlot> &subslots,
1291  OpBuilder &builder,
1292  const DataLayout &dataLayout) {
1293  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
1294 }
1295 
1296 bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
1297  return memcpyLoadsFrom(*this, slot);
1298 }
1299 
1300 bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
1301  return memcpyStoresTo(*this, slot);
1302 }
1303 
1304 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
1305  OpBuilder &builder, Value reachingDef,
1306  const DataLayout &dataLayout) {
1307  return memcpyGetStored(*this, slot, builder);
1308 }
1309 
1310 bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
1311  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1312  SmallVectorImpl<OpOperand *> &newBlockingUses,
1313  const DataLayout &dataLayout) {
1314  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1315  dataLayout);
1316 }
1317 
1318 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
1319  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1320  OpBuilder &builder, Value reachingDefinition,
1321  const DataLayout &dataLayout) {
1322  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
1323  reachingDefinition);
1324 }
1325 
1326 LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
1327  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1328  const DataLayout &dataLayout) {
1329  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1330 }
1331 
1332 bool LLVM::MemcpyInlineOp::canRewire(
1333  const DestructurableMemorySlot &slot,
1334  SmallPtrSetImpl<Attribute> &usedIndices,
1335  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1336  const DataLayout &dataLayout) {
1337  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1338  dataLayout);
1339 }
1340 
1341 DeletionKind
1342 LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
1343  DenseMap<Attribute, MemorySlot> &subslots,
1344  OpBuilder &builder, const DataLayout &dataLayout) {
1345  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
1346 }
1347 
1348 bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
1349  return memcpyLoadsFrom(*this, slot);
1350 }
1351 
1352 bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
1353  return memcpyStoresTo(*this, slot);
1354 }
1355 
1356 Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1357  Value reachingDef,
1358  const DataLayout &dataLayout) {
1359  return memcpyGetStored(*this, slot, builder);
1360 }
1361 
1362 bool LLVM::MemmoveOp::canUsesBeRemoved(
1363  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1364  SmallVectorImpl<OpOperand *> &newBlockingUses,
1365  const DataLayout &dataLayout) {
1366  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1367  dataLayout);
1368 }
1369 
1370 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
1371  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1372  OpBuilder &builder, Value reachingDefinition,
1373  const DataLayout &dataLayout) {
1374  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
1375  reachingDefinition);
1376 }
1377 
1378 LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
1379  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1380  const DataLayout &dataLayout) {
1381  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1382 }
1383 
1384 bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
1385  SmallPtrSetImpl<Attribute> &usedIndices,
1386  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1387  const DataLayout &dataLayout) {
1388  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1389  dataLayout);
1390 }
1391 
1392 DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
1393  DenseMap<Attribute, MemorySlot> &subslots,
1394  OpBuilder &builder,
1395  const DataLayout &dataLayout) {
1396  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
1397 }
1398 
1399 //===----------------------------------------------------------------------===//
1400 // Interfaces for destructurable types
1401 //===----------------------------------------------------------------------===//
1402 
1403 std::optional<DenseMap<Attribute, Type>>
1404 LLVM::LLVMStructType::getSubelementIndexMap() {
1405  Type i32 = IntegerType::get(getContext(), 32);
1406  DenseMap<Attribute, Type> destructured;
1407  for (const auto &[index, elemType] : llvm::enumerate(getBody()))
1408  destructured.insert({IntegerAttr::get(i32, index), elemType});
1409  return destructured;
1410 }
1411 
1412 Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
1413  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1414  if (!indexAttr || !indexAttr.getType().isInteger(32))
1415  return {};
1416  int32_t indexInt = indexAttr.getInt();
1417  ArrayRef<Type> body = getBody();
1418  if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
1419  return {};
1420  return body[indexInt];
1421 }
1422 
1423 std::optional<DenseMap<Attribute, Type>>
1424 LLVM::LLVMArrayType::getSubelementIndexMap() const {
1425  constexpr size_t maxArraySizeForDestructuring = 16;
1426  if (getNumElements() > maxArraySizeForDestructuring)
1427  return {};
1428  int32_t numElements = getNumElements();
1429 
1430  Type i32 = IntegerType::get(getContext(), 32);
1431  DenseMap<Attribute, Type> destructured;
1432  for (int32_t index = 0; index < numElements; ++index)
1433  destructured.insert({IntegerAttr::get(i32, index), getElementType()});
1434  return destructured;
1435 }
1436 
1437 Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
1438  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1439  if (!indexAttr || !indexAttr.getType().isInteger(32))
1440  return {};
1441  int32_t indexInt = indexAttr.getInt();
1442  if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
1443  return {};
1444  return getElementType();
1445 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context, unsigned size)
Constructs a byte array type of the given size.
static LogicalResult memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed)
static std::optional< uint64_t > gepToByteOffset(const DataLayout &dataLayout, LLVM::GEPOp gep)
Returns the amount of bytes the provided GEP elements will offset the pointer by.
static bool areAllIndicesI32(const DestructurableMemorySlot &slot)
Checks whether all indices are i32.
static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val, const DataLayout &dataLayout)
Converts a value to an integer type of the same size.
static Value castSameSizedTypes(OpBuilder &builder, Location loc, Value srcValue, Type targetType, const DataLayout &dataLayout)
Constructs operations that convert srcValue into a new value of type targetType.
static std::optional< SubslotAccessInfo > getSubslotAccessInfo(const DestructurableMemorySlot &slot, const DataLayout &dataLayout, LLVM::GEPOp gep)
Computes subslot access information for an access into slot with the given offset.
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot)
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool areConversionCompatible(const DataLayout &layout, Type targetType, Type srcType, bool narrowingConversion)
Checks that rhs can be converted to lhs by a sequence of casts and truncations.
static bool forwardToUsers(Operation *op, SmallVectorImpl< OpOperand * > &newBlockingUses)
Conditions the deletion of the operation to the removal of all its uses.
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot)
static bool isSupportedTypeForConversion(Type type)
Checks if type can be used in any kind of conversion sequences.
static Value createExtractAndCast(OpBuilder &builder, Location loc, Value srcValue, Type targetType, const DataLayout &dataLayout)
Constructs operations that convert srcValue into a new value of type targetType.
static Value createInsertAndCast(OpBuilder &builder, Location loc, Value srcValue, Value reachingDef, const DataLayout &dataLayout)
Constructs operations that insert the bits of srcValue into the "beginning" of reachingDef (beginning...
static DeletionKind memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, OpBuilder &builder, Value reachingDefinition)
static bool memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, SmallVectorImpl< OpOperand * > &newBlockingUses, const DataLayout &dataLayout)
static bool isBigEndian(const DataLayout &dataLayout)
Checks if dataLayout describes a big endian layout.
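A minimal sketch of such an endianness query, assuming the data layout exposes its endianness entry as a StringAttr holding "big" or "little"; the helper name looksBigEndian is illustrative and not the file's own helper.

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

using namespace mlir;

// Illustrative sketch: report big endianness when the layout's endianness
// entry is the string attribute "big".
static bool looksBigEndian(const DataLayout &dataLayout) {
  auto endianness =
      llvm::dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
  return endianness && endianness.getValue() == "big";
}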
static bool hasAllZeroIndices(LLVM::GEPOp gepOp)
static bool isValidAccessType(const MemorySlot &slot, Type accessType, const DataLayout &dataLayout)
Checks if slot can be accessed through the provided access type.
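Assuming the access check reduces to a size comparison against the slot's element type, the idea can be sketched as follows; whether the actual helper imposes further restrictions is not stated by the brief above, so treat this as an approximation.

#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"

using namespace mlir;

// Illustrative sketch: an access type fits the slot when the data layout says
// it is no larger than the slot's element type.
static bool accessFitsSlot(const MemorySlot &slot, Type accessType,
                           const DataLayout &dataLayout) {
  return dataLayout.getTypeSize(accessType) <=
         dataLayout.getTypeSize(slot.elemType);
}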
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot, OpBuilder &builder)
static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc, Value val, Type targetType)
Converts a value with an integer type to targetType.
static DeletionKind memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot, DenseMap< Attribute, MemorySlot > &subslots, OpBuilder &builder, const DataLayout &dataLayout)
Rewires a memcpy-like operation.
static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot, const DataLayout &dataLayout)
Returns whether one can be sure the memory intrinsic does not write outside of the bounds of the give...
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot, SmallPtrSetImpl< Attribute > &usedIndices, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed, const DataLayout &dataLayout)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1545
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
MLIRContext * getContext() const
Definition: Builders.h:55
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Attribute getEndianness() const
Returns the specified endianness.
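The DataLayout queries listed above are the ones this file relies on. A short sketch of how they are typically obtained and used; the function inspectType is illustrative.

#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

using namespace mlir;

// Illustrative sketch: fetch the layout from the closest enclosing operation
// that carries data layout information, then query a type against it.
static void inspectType(Operation *op, Type type) {
  DataLayout dataLayout = DataLayout::closest(op);
  llvm::TypeSize sizeInBytes = dataLayout.getTypeSize(type);
  llvm::TypeSize sizeInBits = dataLayout.getTypeSizeInBits(type);
  uint64_t abiAlignment = dataLayout.getTypeABIAlignment(type);
  (void)sizeInBytes;
  (void)sizeInBits;
  (void)abiAlignment;
}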
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:108
Type getTypeAtIndex(Attribute index)
Returns which type is stored at a given integer index within the struct.
bool isPacked() const
Checks if a struct is packed.
Definition: LLVMTypes.cpp:482
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Definition: LLVMTypes.cpp:490
std::optional< DenseMap< Attribute, Type > > getSubelementIndexMap()
Destructs the struct into its indexed field types.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:523
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:415
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr int kGEPConstantBitWidth
Bit-width of a 'GEPConstantIndex' within GEPArg.
Definition: LLVMDialect.h:65
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > subelementTypes
Maps an index within the memory slot to the corresponding subelement type.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.
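The two structs above are plain data carriers. A hedged sketch of how they can be populated, assuming DestructurableMemorySlot supports aggregate initialization from a MemorySlot base plus the subelement map (as its documented fields suggest); the helper makeDestructurableSlot is illustrative.

#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;

// Illustrative sketch: bundle a slot pointer and element type, then attach
// the per-index subelement types used during destructuring.
static DestructurableMemorySlot
makeDestructurableSlot(Value ptr, Type elemType,
                       DenseMap<Attribute, Type> subelementTypes) {
  MemorySlot slot{ptr, elemType};
  return DestructurableMemorySlot{slot, std::move(subelementTypes)};
}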