MLIR  19.0.0git
LLVMMemorySlot.cpp
Go to the documentation of this file.
1 //===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MemorySlot-related interfaces for LLVM dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define DEBUG_TYPE "sroa"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Interfaces for AllocaOp
29 //===----------------------------------------------------------------------===//
30 
31 llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
32  if (!getOperation()->getBlock()->isEntryBlock())
33  return {};
34 
35  return {MemorySlot{getResult(), getElemType()}};
36 }
37 
38 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
39  RewriterBase &rewriter) {
40  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
41 }
42 
43 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
44  BlockArgument argument,
45  RewriterBase &rewriter) {
46  for (Operation *user : getOperation()->getUsers())
47  if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
48  rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
49  declareOp.getVarInfo(),
50  declareOp.getLocationExpr());
51 }
52 
53 void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
54  Value defaultValue,
55  RewriterBase &rewriter) {
56  if (defaultValue && defaultValue.use_empty())
57  rewriter.eraseOp(defaultValue.getDefiningOp());
58  rewriter.eraseOp(*this);
59 }
60 
61 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
62  if (!mlir::matchPattern(getArraySize(), m_One()))
63  return {};
64 
65  auto destructurable = dyn_cast<DestructurableTypeInterface>(getElemType());
66  if (!destructurable)
67  return {};
68 
69  std::optional<DenseMap<Attribute, Type>> destructuredType =
70  destructurable.getSubelementIndexMap();
71  if (!destructuredType)
72  return {};
73 
74  return {DestructurableMemorySlot{{getResult(), getElemType()},
75  *destructuredType}};
76 }
77 
// Creates one new alloca per index in `usedIndices`, right after this alloca.
// Each new alloca has the type that the element type's
// DestructurableTypeInterface reports for that index, and it reuses this
// alloca's array-size operand. Returns the map from each index to its new
// subslot (pointer + element type).
// NOTE(review): this listing lacks the return-type line (presumably
// `DenseMap<Attribute, MemorySlot>`) and the declaration of `slotMap`;
// confirm against the upstream file.
79 LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
80  const SmallPtrSetImpl<Attribute> &usedIndices,
81  RewriterBase &rewriter) {
82  assert(slot.ptr == getResult());
// New subslot allocas are inserted immediately after the original alloca.
83  rewriter.setInsertionPointAfter(*this);
84 
85  auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
87  for (Attribute index : usedIndices) {
88  Type elemType = destructurableType.getTypeAtIndex(index);
89  assert(elemType && "used index must exist");
// Same array size as the original; getDestructurableSlots only accepts a
// constant size of one.
90  auto subAlloca = rewriter.create<LLVM::AllocaOp>(
91  getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
92  getArraySize());
93  slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
94  }
95 
96  return slotMap;
97 }
98 
99 void LLVM::AllocaOp::handleDestructuringComplete(
100  const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
101  assert(slot.ptr == getResult());
102  rewriter.eraseOp(*this);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Interfaces for LoadOp/StoreOp
107 //===----------------------------------------------------------------------===//
108 
109 bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
110  return getAddr() == slot.ptr;
111 }
112 
113 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
114 
115 Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
116  llvm_unreachable("getStored should not be called on LoadOp");
117 }
118 
119 bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
120 
121 bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
122  return getAddr() == slot.ptr;
123 }
124 
125 /// Checks that two types are the same or can be cast into one another.
126 static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
127  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
128  !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
129  layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
130 }
131 
132 /// Constructs operations that convert `inputValue` into a new value of type
133 /// `targetType`. Assumes that this conversion is possible.
/// Conversion chosen, in order:
///  - same type: `inputValue` is returned unchanged;
///  - neither side is an LLVM pointer: `llvm.bitcast`;
///  - only the input is a pointer: `llvm.ptrtoint`;
///  - only the target is a pointer: `llvm.inttoptr`;
///  - both are pointers: `llvm.addrspacecast`.
/// All ops are built with createOrFold, so they may fold away.
/// NOTE(review): the signature line is missing from this listing; judging from
/// the call sites it is presumably
/// `static Value createConversionSequence(RewriterBase &rewriter, Location loc,`.
135  Value inputValue, Type targetType) {
136  if (inputValue.getType() == targetType)
137  return inputValue;
138 
139  if (!isa<LLVM::LLVMPointerType>(targetType) &&
140  !isa<LLVM::LLVMPointerType>(inputValue.getType()))
141  return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
142 
// From here on at least one side is a pointer.
143  if (!isa<LLVM::LLVMPointerType>(targetType))
144  return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
145 
146  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
147  return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
148 
// Both sides are pointers of different types, i.e. different address spaces.
149  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
150  inputValue);
151 }
152 
153 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
154  return createConversionSequence(rewriter, getLoc(), getValue(),
155  slot.elemType);
156 }
157 
158 bool LLVM::LoadOp::canUsesBeRemoved(
159  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
160  SmallVectorImpl<OpOperand *> &newBlockingUses,
161  const DataLayout &dataLayout) {
162  if (blockingUses.size() != 1)
163  return false;
164  Value blockingUse = (*blockingUses.begin())->get();
165  // If the blocking use is the slot ptr itself, there will be enough
166  // context to reconstruct the result of the load at removal time, so it can
167  // be removed (provided it loads the exact stored value and is not
168  // volatile).
169  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
170  areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
171  !getVolatile_();
172 }
173 
174 DeletionKind LLVM::LoadOp::removeBlockingUses(
175  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
176  RewriterBase &rewriter, Value reachingDefinition) {
177  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
178  // pointer.
179  Value newResult = createConversionSequence(
180  rewriter, getLoc(), reachingDefinition, getResult().getType());
181  rewriter.replaceAllUsesWith(getResult(), newResult);
182  return DeletionKind::Delete;
183 }
184 
185 bool LLVM::StoreOp::canUsesBeRemoved(
186  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
187  SmallVectorImpl<OpOperand *> &newBlockingUses,
188  const DataLayout &dataLayout) {
189  if (blockingUses.size() != 1)
190  return false;
191  Value blockingUse = (*blockingUses.begin())->get();
192  // If the blocking use is the slot ptr itself, dropping the store is
193  // fine, provided we are currently promoting its target value. Don't allow a
194  // store OF the slot pointer, only INTO the slot pointer.
195  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
196  getValue() != slot.ptr &&
197  areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
198  !getVolatile_();
199 }
200 
201 DeletionKind LLVM::StoreOp::removeBlockingUses(
202  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
203  RewriterBase &rewriter, Value reachingDefinition) {
204  return DeletionKind::Delete;
205 }
206 
207 /// Checks if `slot` can be accessed through the provided access type.
208 static bool isValidAccessType(const MemorySlot &slot, Type accessType,
209  const DataLayout &dataLayout) {
210  return dataLayout.getTypeSize(accessType) <=
211  dataLayout.getTypeSize(slot.elemType);
212 }
213 
214 LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
215  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
216  const DataLayout &dataLayout) {
217  return success(getAddr() != slot.ptr ||
218  isValidAccessType(slot, getType(), dataLayout));
219 }
220 
221 LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
222  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
223  const DataLayout &dataLayout) {
224  return success(getAddr() != slot.ptr ||
225  isValidAccessType(slot, getValue().getType(), dataLayout));
226 }
227 
228 /// Returns the subslot's type at the requested index.
/// Returns a null Type when the slot's element type cannot report a subelement
/// index map, or when `index` is not a key of that map.
/// NOTE(review): the signature line is missing from this listing; judging from
/// the call sites it presumably reads
/// `static Type getTypeAtIndex(const DestructurableMemorySlot &slot,`.
230  Attribute index) {
// The cast asserts that the element type implements DestructurableTypeInterface.
231  auto subelementIndexMap =
232  slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
233  if (!subelementIndexMap)
234  return {};
235  assert(!subelementIndexMap->empty());
236 
237  // Note: Returns a null-type when no entry was found.
238  return subelementIndexMap->lookup(index);
239 }
240 
241 bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
242  SmallPtrSetImpl<Attribute> &usedIndices,
243  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
244  const DataLayout &dataLayout) {
245  if (getVolatile_())
246  return false;
247 
248  // A load always accesses the first element of the destructured slot.
249  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
250  Type subslotType = getTypeAtIndex(slot, index);
251  if (!subslotType)
252  return false;
253 
254  // The access can only be replaced when the subslot is read within its bounds.
255  if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
256  return false;
257 
258  usedIndices.insert(index);
259  return true;
260 }
261 
// Redirects the load's address to the subslot at i32 index 0, which canRewire
// recorded as used. The load itself is kept (DeletionKind::Keep).
// NOTE(review): the parameter line declaring `subslots` is missing from this
// listing; it is presumably a `DenseMap<Attribute, MemorySlot> &` (it is used
// with find/getSecond().ptr).
262 DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
264  RewriterBase &rewriter,
265  const DataLayout &dataLayout) {
266  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
267  auto it = subslots.find(index);
268  assert(it != subslots.end());
269 
270  rewriter.modifyOpInPlace(
271  *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
272  return DeletionKind::Keep;
273 }
274 
275 bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
276  SmallPtrSetImpl<Attribute> &usedIndices,
277  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
278  const DataLayout &dataLayout) {
279  if (getVolatile_())
280  return false;
281 
282  // Storing the pointer to memory cannot be dealt with.
283  if (getValue() == slot.ptr)
284  return false;
285 
286  // A store always accesses the first element of the destructured slot.
287  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
288  Type subslotType = getTypeAtIndex(slot, index);
289  if (!subslotType)
290  return false;
291 
292  // The access can only be replaced when the subslot is read within its bounds.
293  if (dataLayout.getTypeSize(getValue().getType()) >
294  dataLayout.getTypeSize(subslotType))
295  return false;
296 
297  usedIndices.insert(index);
298  return true;
299 }
300 
// Redirects the store's address to the subslot at i32 index 0, which canRewire
// recorded as used. The store itself is kept (DeletionKind::Keep).
// NOTE(review): the parameter line declaring `subslots` is missing from this
// listing; it is presumably a `DenseMap<Attribute, MemorySlot> &` (it is used
// with find/getSecond().ptr).
301 DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
303  RewriterBase &rewriter,
304  const DataLayout &dataLayout) {
305  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
306  auto it = subslots.find(index);
307  assert(it != subslots.end());
308 
309  rewriter.modifyOpInPlace(
310  *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
311  return DeletionKind::Keep;
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Interfaces for discardable OPs
316 //===----------------------------------------------------------------------===//
317 
318 /// Conditions the deletion of the operation to the removal of all its uses.
319 static bool forwardToUsers(Operation *op,
320  SmallVectorImpl<OpOperand *> &newBlockingUses) {
321  for (Value result : op->getResults())
322  for (OpOperand &use : result.getUses())
323  newBlockingUses.push_back(&use);
324  return true;
325 }
326 
327 bool LLVM::BitcastOp::canUsesBeRemoved(
328  const SmallPtrSetImpl<OpOperand *> &blockingUses,
329  SmallVectorImpl<OpOperand *> &newBlockingUses,
330  const DataLayout &dataLayout) {
331  return forwardToUsers(*this, newBlockingUses);
332 }
333 
334 DeletionKind LLVM::BitcastOp::removeBlockingUses(
335  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
336  return DeletionKind::Delete;
337 }
338 
339 bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
340  const SmallPtrSetImpl<OpOperand *> &blockingUses,
341  SmallVectorImpl<OpOperand *> &newBlockingUses,
342  const DataLayout &dataLayout) {
343  return forwardToUsers(*this, newBlockingUses);
344 }
345 
346 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
347  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
348  return DeletionKind::Delete;
349 }
350 
351 bool LLVM::LifetimeStartOp::canUsesBeRemoved(
352  const SmallPtrSetImpl<OpOperand *> &blockingUses,
353  SmallVectorImpl<OpOperand *> &newBlockingUses,
354  const DataLayout &dataLayout) {
355  return true;
356 }
357 
358 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
359  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
360  return DeletionKind::Delete;
361 }
362 
363 bool LLVM::LifetimeEndOp::canUsesBeRemoved(
364  const SmallPtrSetImpl<OpOperand *> &blockingUses,
365  SmallVectorImpl<OpOperand *> &newBlockingUses,
366  const DataLayout &dataLayout) {
367  return true;
368 }
369 
370 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
371  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
372  return DeletionKind::Delete;
373 }
374 
375 bool LLVM::InvariantStartOp::canUsesBeRemoved(
376  const SmallPtrSetImpl<OpOperand *> &blockingUses,
377  SmallVectorImpl<OpOperand *> &newBlockingUses,
378  const DataLayout &dataLayout) {
379  return true;
380 }
381 
382 DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
383  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
384  return DeletionKind::Delete;
385 }
386 
387 bool LLVM::InvariantEndOp::canUsesBeRemoved(
388  const SmallPtrSetImpl<OpOperand *> &blockingUses,
389  SmallVectorImpl<OpOperand *> &newBlockingUses,
390  const DataLayout &dataLayout) {
391  return true;
392 }
393 
394 DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
395  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
396  return DeletionKind::Delete;
397 }
398 
399 bool LLVM::DbgDeclareOp::canUsesBeRemoved(
400  const SmallPtrSetImpl<OpOperand *> &blockingUses,
401  SmallVectorImpl<OpOperand *> &newBlockingUses,
402  const DataLayout &dataLayout) {
403  return true;
404 }
405 
406 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
407  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
408  return DeletionKind::Delete;
409 }
410 
411 bool LLVM::DbgValueOp::canUsesBeRemoved(
412  const SmallPtrSetImpl<OpOperand *> &blockingUses,
413  SmallVectorImpl<OpOperand *> &newBlockingUses,
414  const DataLayout &dataLayout) {
415  // There is only one operand that we can remove the use of.
416  if (blockingUses.size() != 1)
417  return false;
418 
419  return (*blockingUses.begin())->get() == getValue();
420 }
421 
422 DeletionKind LLVM::DbgValueOp::removeBlockingUses(
423  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
424  // Rewriter by default is after '*this', but we need it before '*this'.
425  rewriter.setInsertionPoint(*this);
426 
427  // Rather than dropping the debug value, replace it with undef to preserve the
428  // debug local variable info. This allows the debugger to inform the user that
429  // the variable has been optimized out.
430  auto undef =
431  rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
432  rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); });
433  return DeletionKind::Keep;
434 }
435 
436 bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
437 
438 void LLVM::DbgDeclareOp::visitReplacedValues(
439  ArrayRef<std::pair<Operation *, Value>> definitions,
440  RewriterBase &rewriter) {
441  for (auto [op, value] : definitions) {
442  rewriter.setInsertionPointAfter(op);
443  rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
444  getLocationExpr());
445  }
446 }
447 
448 //===----------------------------------------------------------------------===//
449 // Interfaces for GEPOp
450 //===----------------------------------------------------------------------===//
451 
452 static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
453  return llvm::all_of(gepOp.getIndices(), [](auto index) {
454  auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
455  return indexAttr && indexAttr.getValue() == 0;
456  });
457 }
458 
459 bool LLVM::GEPOp::canUsesBeRemoved(
460  const SmallPtrSetImpl<OpOperand *> &blockingUses,
461  SmallVectorImpl<OpOperand *> &newBlockingUses,
462  const DataLayout &dataLayout) {
463  // GEP can be removed as long as it is a no-op and its users can be removed.
464  if (!hasAllZeroIndices(*this))
465  return false;
466  return forwardToUsers(*this, newBlockingUses);
467 }
468 
469 DeletionKind LLVM::GEPOp::removeBlockingUses(
470  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
471  return DeletionKind::Delete;
472 }
473 
474 /// Returns the amount of bytes the provided GEP elements will offset the
475 /// pointer by. Returns nullopt if no constant offset could be computed.
476 static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
477  LLVM::GEPOp gep) {
478  // Collects all indices.
479  SmallVector<uint64_t> indices;
480  for (auto index : gep.getIndices()) {
481  auto constIndex = dyn_cast<IntegerAttr>(index);
482  if (!constIndex)
483  return {};
484  int64_t gepIndex = constIndex.getInt();
485  // Negative indices are not supported.
486  if (gepIndex < 0)
487  return {};
488  indices.push_back(gepIndex);
489  }
490 
491  Type currentType = gep.getElemType();
492  uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);
493 
494  for (uint64_t index : llvm::drop_begin(indices)) {
495  bool shouldCancel =
496  TypeSwitch<Type, bool>(currentType)
497  .Case([&](LLVM::LLVMArrayType arrayType) {
498  offset +=
499  index * dataLayout.getTypeSize(arrayType.getElementType());
500  currentType = arrayType.getElementType();
501  return false;
502  })
503  .Case([&](LLVM::LLVMStructType structType) {
504  ArrayRef<Type> body = structType.getBody();
505  assert(index < body.size() && "expected valid struct indexing");
506  for (uint32_t i : llvm::seq(index)) {
507  if (!structType.isPacked())
508  offset = llvm::alignTo(
509  offset, dataLayout.getTypeABIAlignment(body[i]));
510  offset += dataLayout.getTypeSize(body[i]);
511  }
512 
513  // Align for the current type as well.
514  if (!structType.isPacked())
515  offset = llvm::alignTo(
516  offset, dataLayout.getTypeABIAlignment(body[index]));
517  currentType = body[index];
518  return false;
519  })
520  .Default([&](Type type) {
521  LLVM_DEBUG(llvm::dbgs()
522  << "[sroa] Unsupported type for offset computations"
523  << type << "\n");
524  return true;
525  });
526 
527  if (shouldCancel)
528  return std::nullopt;
529  }
530 
531  return offset;
532 }
533 
namespace {
/// Where an access into a destructured slot lands: which subelement of the
/// parent slot it falls into, and the byte offset of the access inside that
/// subelement.
struct SubslotAccessInfo {
  uint32_t index;         ///< Index of the parent slot's subelement accessed.
  uint64_t subslotOffset; ///< Byte offset of the access within that subslot.
};
} // namespace
544 
545 /// Computes subslot access information for an access into `slot` with the given
546 /// offset.
547 /// Returns nullopt when the offset is out-of-bounds or when the access is into
548 /// the padding of `slot`.
/// The offset is the constant byte offset that `gep` computes (via
/// gepToByteOffset); a non-constant GEP also yields nullopt. Array slots map
/// the offset to an element by division. Struct slots walk their fields,
/// honoring ABI alignment unless the struct is packed. Indices that do not fit
/// in LLVM::kGEPConstantBitWidth bits are rejected.
/// NOTE(review): this listing lacks the signature line (presumably
/// `getSubslotAccessInfo(const DestructurableMemorySlot &slot,`) and the
/// TypeSwitch head line before `.Case` (presumably `return TypeSwitch<...>(type)`).
549 static std::optional<SubslotAccessInfo>
551  const DataLayout &dataLayout, LLVM::GEPOp gep) {
552  std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
553  if (!offset)
554  return {};
555 
556  // Helper to check that a constant index is in the bounds of the GEP index
557  // representation. LLVM dialect's GEP arguments have a limited bitwidth, thus
558  // this additional check is necessary.
559  auto isOutOfBoundsGEPIndex = [](uint64_t index) {
560  return index >= (1 << LLVM::kGEPConstantBitWidth);
561  };
562 
563  Type type = slot.elemType;
// An offset at or past the end of the slot is out of bounds. This also
// rejects zero-sized slot types, so the array division below is not by zero.
564  if (*offset >= dataLayout.getTypeSize(type))
565  return {};
567  .Case([&](LLVM::LLVMArrayType arrayType)
568  -> std::optional<SubslotAccessInfo> {
569  // Find which element of the array contains the offset.
570  uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
571  uint64_t index = *offset / elemSize;
572  if (isOutOfBoundsGEPIndex(index))
573  return {};
574  return SubslotAccessInfo{static_cast<uint32_t>(index),
575  *offset - (index * elemSize)};
576  })
577  .Case([&](LLVM::LLVMStructType structType)
578  -> std::optional<SubslotAccessInfo> {
579  uint64_t distanceToStart = 0;
580  // Walk over the elements of the struct to find in which of
581  // them the offset is.
582  for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
583  uint64_t elemSize = dataLayout.getTypeSize(elem);
584  if (!structType.isPacked()) {
585  distanceToStart = llvm::alignTo(
586  distanceToStart, dataLayout.getTypeABIAlignment(elem));
587  // If the offset is in padding, cancel the rewrite.
// NOTE(review): `offset` is a std::optional here (engaged at this point), so
// this compares its contained value; `*offset` would state that more clearly.
588  if (offset < distanceToStart)
589  return {};
590  }
591 
592  if (offset < distanceToStart + elemSize) {
593  if (isOutOfBoundsGEPIndex(index))
594  return {};
595  // The offset is within this element, stop iterating the
596  // struct and return the index.
597  return SubslotAccessInfo{static_cast<uint32_t>(index),
598  *offset - distanceToStart};
599  }
600 
601  // The offset is not within this element, continue walking
602  // over the struct.
603  distanceToStart += elemSize;
604  }
605 
// Reached only if the offset falls into trailing padding of the struct.
606  return {};
607  });
608 }
609 
610 /// Constructs a byte array type of the given size.
611 static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
612  unsigned size) {
613  auto byteType = IntegerType::get(context, 8);
614  return LLVM::LLVMArrayType::get(context, byteType, size);
615 }
616 
617 LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
618  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
619  const DataLayout &dataLayout) {
620  if (getBase() != slot.ptr)
621  return success();
622  std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
623  if (!gepOffset)
624  return failure();
625  uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
626  // Check that the access is strictly inside the slot.
627  if (*gepOffset >= slotSize)
628  return failure();
629  // Every access that remains in bounds of the remaining slot is considered
630  // legal.
631  mustBeSafelyUsed.emplace_back<MemorySlot>(
632  {getRes(), getByteArrayType(getContext(), slotSize - *gepOffset)});
633  return success();
634 }
635 
636 bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
637  SmallPtrSetImpl<Attribute> &usedIndices,
638  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
639  const DataLayout &dataLayout) {
640  if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
641  return false;
642 
643  if (getBase() != slot.ptr)
644  return false;
645  std::optional<SubslotAccessInfo> accessInfo =
646  getSubslotAccessInfo(slot, dataLayout, *this);
647  if (!accessInfo)
648  return false;
649  auto indexAttr =
650  IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
651  assert(slot.elementPtrs.contains(indexAttr));
652  usedIndices.insert(indexAttr);
653 
654  // The remainder of the subslot should be accesses in-bounds. Thus, we create
655  // a dummy slot with the size of the remainder.
656  Type subslotType = slot.elementPtrs.lookup(indexAttr);
657  uint64_t slotSize = dataLayout.getTypeSize(subslotType);
658  LLVM::LLVMArrayType remainingSlotType =
659  getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
660  mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});
661 
662  return true;
663 }
664 
// Replaces this GEP by a byte-wise (i8) GEP into the subslot that its offset
// falls into. The new GEP uses the in-subslot offset and keeps this GEP's
// result type and inbounds flag. All uses are redirected, and this GEP is
// deleted.
// NOTE(review): the parameter line declaring `subslots` is missing from this
// listing; it is presumably a `DenseMap<Attribute, MemorySlot> &` (it is used
// with at(...)).
665 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
667  RewriterBase &rewriter,
668  const DataLayout &dataLayout) {
669  std::optional<SubslotAccessInfo> accessInfo =
670  getSubslotAccessInfo(slot, dataLayout, *this);
671  assert(accessInfo && "expected access info to be checked before");
672  auto indexAttr =
673  IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
674  const MemorySlot &newSlot = subslots.at(indexAttr);
675 
// With an i8 element type, the single GEP index is exactly the byte offset.
676  auto byteType = IntegerType::get(rewriter.getContext(), 8);
677  auto newPtr = rewriter.createOrFold<LLVM::GEPOp>(
678  getLoc(), getResult().getType(), byteType, newSlot.ptr,
679  ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
680  rewriter.replaceAllUsesWith(getResult(), newPtr);
681  return DeletionKind::Delete;
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // Utilities for memory intrinsics
686 //===----------------------------------------------------------------------===//
687 
688 namespace {
689 
690 /// Returns the length of the given memory intrinsic in bytes if it can be known
691 /// at compile-time on a best-effort basis, nothing otherwise.
692 template <class MemIntr>
693 std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
694  APInt memIntrLen;
695  if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
696  return {};
697  if (memIntrLen.getBitWidth() > 64)
698  return {};
699  return memIntrLen.getZExtValue();
700 }
701 
702 /// Returns the length of the given memory intrinsic in bytes if it can be known
703 /// at compile-time on a best-effort basis, nothing otherwise.
704 /// Because MemcpyInlineOp has its length encoded as an attribute, this requires
705 /// specialized handling.
706 template <>
707 std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
708  APInt memIntrLen = op.getLen();
709  if (memIntrLen.getBitWidth() > 64)
710  return {};
711  return memIntrLen.getZExtValue();
712 }
713 
714 } // namespace
715 
716 /// Returns whether one can be sure the memory intrinsic does not write outside
717 /// of the bounds of the given slot, on a best-effort basis.
718 template <class MemIntr>
719 static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
720  const DataLayout &dataLayout) {
721  if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
722  op.getDst() != slot.ptr)
723  return false;
724 
725  std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op);
726  return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(slot.elemType);
727 }
728 
729 /// Checks whether all indices are i32. This is used to check GEPs can index
730 /// into them.
731 static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
732  Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
733  return llvm::all_of(llvm::make_first_range(slot.elementPtrs),
734  [&](Attribute index) {
735  auto intIndex = dyn_cast<IntegerAttr>(index);
736  return intIndex && intIndex.getType() == i32;
737  });
738 }
739 
740 //===----------------------------------------------------------------------===//
741 // Interfaces for memset
742 //===----------------------------------------------------------------------===//
743 
744 bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
745 
746 bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
747  return getDst() == slot.ptr;
748 }
749 
750 Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
751  RewriterBase &rewriter) {
752  // TODO: Support non-integer types.
753  return TypeSwitch<Type, Value>(slot.elemType)
754  .Case([&](IntegerType intType) -> Value {
755  if (intType.getWidth() == 8)
756  return getVal();
757 
758  assert(intType.getWidth() % 8 == 0);
759 
760  // Build the memset integer by repeatedly shifting the value and
761  // or-ing it with the previous value.
762  uint64_t coveredBits = 8;
763  Value currentValue =
764  rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
765  while (coveredBits < intType.getWidth()) {
766  Value shiftBy =
767  rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
768  Value shifted =
769  rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
770  currentValue =
771  rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
772  coveredBits *= 2;
773  }
774 
775  return currentValue;
776  })
777  .Default([](Type) -> Value {
778  llvm_unreachable(
779  "getStored should not be called on memset to unsupported type");
780  });
781 }
782 
// A memset into the slot can be promoted only when all of the following hold:
//  - the slot holds an integer whose width is a non-zero multiple of 8;
//  - the memset is not volatile;
//  - its constant length equals the slot size exactly.
// A non-constant length gives std::nullopt, which compares unequal to the size.
// NOTE(review): the TypeSwitch head line before `.Case` is missing from this
// listing (presumably `TypeSwitch<Type, bool>(slot.elemType)`).
783 bool LLVM::MemsetOp::canUsesBeRemoved(
784  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
785  SmallVectorImpl<OpOperand *> &newBlockingUses,
786  const DataLayout &dataLayout) {
787  // TODO: Support non-integer types.
788  bool canConvertType =
790  .Case([](IntegerType intType) {
791  return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
792  })
793  .Default([](Type) { return false; });
794  if (!canConvertType)
795  return false;
796 
797  if (getIsVolatile())
798  return false;
799 
800  return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
801 }
802 
803 DeletionKind LLVM::MemsetOp::removeBlockingUses(
804  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
805  RewriterBase &rewriter, Value reachingDefinition) {
806  return DeletionKind::Delete;
807 }
808 
809 LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
810  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
811  const DataLayout &dataLayout) {
812  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
813 }
814 
815 bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
816  SmallPtrSetImpl<Attribute> &usedIndices,
817  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
818  const DataLayout &dataLayout) {
819  if (&slot.elemType.getDialect() != getOperation()->getDialect())
820  return false;
821 
822  if (getIsVolatile())
823  return false;
824 
825  if (!slot.elemType.cast<DestructurableTypeInterface>()
826  .getSubelementIndexMap())
827  return false;
828 
829  if (!areAllIndicesI32(slot))
830  return false;
831 
832  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
833 }
834 
// Replaces this memset with one memset per used subslot.
//  - Subelements are visited in index order (0, 1, ...), tracking the byte
//    offset `covered`; ABI alignment is applied unless the slot is a packed
//    struct.
//  - Each used subslot receives the same byte value and volatility, with a
//    length of min(bytes of the memset left after `covered`, subelement size).
//  - Subelements starting at or beyond the memset length are not touched.
// The original memset is deleted.
// NOTE(review): the parameter line declaring `subslots` is missing from this
// listing; it is presumably a `DenseMap<Attribute, MemorySlot> &` (it is used
// with contains/at(...).ptr).
835 DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
837  RewriterBase &rewriter,
838  const DataLayout &dataLayout) {
839  std::optional<DenseMap<Attribute, Type>> types =
840  slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
841 
// The length is expected to be constant; presumably guaranteed by canRewire
// through definitelyWritesOnlyWithinSlot/getStaticMemIntrLen.
842  IntegerAttr memsetLenAttr;
843  bool successfulMatch =
844  matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
845  (void)successfulMatch;
846  assert(successfulMatch);
847 
848  bool packed = false;
849  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
850  packed = structType.isPacked();
851 
852  Type i32 = IntegerType::get(getContext(), 32);
853  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
854  uint64_t covered = 0;
855  for (size_t i = 0; i < types->size(); i++) {
856  // Create indices on the fly to get elements in the right order.
857  Attribute index = IntegerAttr::get(i32, i);
858  Type elemType = types->at(index);
859  uint64_t typeSize = dataLayout.getTypeSize(elemType);
860 
861  if (!packed)
862  covered =
863  llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
864 
865  if (covered >= memsetLen)
866  break;
867 
868  // If this subslot is used, apply a new memset to it.
869  // Otherwise, only compute its offset within the original memset.
870  if (subslots.contains(index)) {
871  uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
872 
// The new length constant reuses the integer type of the original length.
873  Value newMemsetSizeValue =
874  rewriter
875  .create<LLVM::ConstantOp>(
876  getLen().getLoc(),
877  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
878  .getResult();
879 
880  rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
881  getVal(), newMemsetSizeValue,
882  getIsVolatile());
883  }
884 
885  covered += typeSize;
886  }
887 
888  return DeletionKind::Delete;
889 }
890 
891 //===----------------------------------------------------------------------===//
892 // Interfaces for memcpy/memmove
893 //===----------------------------------------------------------------------===//
894 
895 template <class MemcpyLike>
896 static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
897  return op.getSrc() == slot.ptr;
898 }
899 
900 template <class MemcpyLike>
901 static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
902  return op.getDst() == slot.ptr;
903 }
904 
905 template <class MemcpyLike>
906 static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
907  RewriterBase &rewriter) {
908  return rewriter.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
909 }
910 
911 template <class MemcpyLike>
912 static bool
913 memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
914  const SmallPtrSetImpl<OpOperand *> &blockingUses,
915  SmallVectorImpl<OpOperand *> &newBlockingUses,
916  const DataLayout &dataLayout) {
917  // If source and destination are the same, memcpy behavior is undefined and
918  // memmove is a no-op. Because there is no memory change happening here,
919  // simplifying such operations is left to canonicalization.
920  if (op.getDst() == op.getSrc())
921  return false;
922 
923  if (op.getIsVolatile())
924  return false;
925 
926  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
927 }
928 
929 template <class MemcpyLike>
930 static DeletionKind
931 memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
932  const SmallPtrSetImpl<OpOperand *> &blockingUses,
933  RewriterBase &rewriter, Value reachingDefinition) {
934  if (op.loadsFrom(slot))
935  rewriter.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition,
936  op.getDst());
937  return DeletionKind::Delete;
938 }
939 
940 template <class MemcpyLike>
941 static LogicalResult
942 memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
943  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
944  DataLayout dataLayout = DataLayout::closest(op);
945  // While rewiring memcpy-like intrinsics only supports full copies, partial
946  // copies are still safe accesses so it is enough to only check for writes
947  // within bounds.
948  return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
949 }
950 
951 template <class MemcpyLike>
952 static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
953  SmallPtrSetImpl<Attribute> &usedIndices,
954  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
955  const DataLayout &dataLayout) {
956  if (op.getIsVolatile())
957  return false;
958 
959  if (!slot.elemType.cast<DestructurableTypeInterface>()
960  .getSubelementIndexMap())
961  return false;
962 
963  if (!areAllIndicesI32(slot))
964  return false;
965 
966  // Only full copies are supported.
967  if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
968  return false;
969 
970  if (op.getSrc() == slot.ptr)
971  for (Attribute index : llvm::make_first_range(slot.elementPtrs))
972  usedIndices.insert(index);
973 
974  return true;
975 }
976 
977 namespace {
978 
979 template <class MemcpyLike>
980 void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
981  MemcpyLike toReplace, Value dst, Value src,
982  Type toCpy, bool isVolatile) {
983  Value memcpySize = rewriter.create<LLVM::ConstantOp>(
984  toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
985  layout.getTypeSize(toCpy)));
986  rewriter.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
987  isVolatile);
988 }
989 
990 template <>
991 void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
992  LLVM::MemcpyInlineOp toReplace, Value dst,
993  Value src, Type toCpy, bool isVolatile) {
994  Type lenType = IntegerType::get(toReplace->getContext(),
995  toReplace.getLen().getBitWidth());
996  rewriter.create<LLVM::MemcpyInlineOp>(
997  toReplace.getLoc(), dst, src,
998  IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
999 }
1000 
1001 } // namespace
1002 
1003 /// Rewires a memcpy-like operation. Only copies to or from the full slot are
1004 /// supported.
1005 template <class MemcpyLike>
1006 static DeletionKind
1007 memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1008  DenseMap<Attribute, MemorySlot> &subslots, RewriterBase &rewriter,
1009  const DataLayout &dataLayout) {
1010  if (subslots.empty())
1011  return DeletionKind::Delete;
1012 
1013  assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
1014  bool isDst = slot.ptr == op.getDst();
1015 
1016 #ifndef NDEBUG
1017  size_t slotsTreated = 0;
1018 #endif
1019 
1020  // It was previously checked that index types are consistent, so this type can
1021  // be fetched now.
1022  Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
1023  for (size_t i = 0, e = slot.elementPtrs.size(); i != e; i++) {
1024  Attribute index = IntegerAttr::get(indexType, i);
1025  if (!subslots.contains(index))
1026  continue;
1027  const MemorySlot &subslot = subslots.at(index);
1028 
1029 #ifndef NDEBUG
1030  slotsTreated++;
1031 #endif
1032 
1033  // First get a pointer to the equivalent of this subslot from the source
1034  // pointer.
1035  SmallVector<LLVM::GEPArg> gepIndices{
1036  0, static_cast<int32_t>(
1037  cast<IntegerAttr>(index).getValue().getZExtValue())};
1038  Value subslotPtrInOther = rewriter.create<LLVM::GEPOp>(
1040  isDst ? op.getSrc() : op.getDst(), gepIndices);
1041 
1042  // Then create a new memcpy out of this source pointer.
1043  createMemcpyLikeToReplace(rewriter, dataLayout, op,
1044  isDst ? subslot.ptr : subslotPtrInOther,
1045  isDst ? subslotPtrInOther : subslot.ptr,
1046  subslot.elemType, op.getIsVolatile());
1047  }
1048 
1049  assert(subslots.size() == slotsTreated);
1050 
1051  return DeletionKind::Delete;
1052 }
1053 
1054 bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
1055  return memcpyLoadsFrom(*this, slot);
1056 }
1057 
1058 bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
1059  return memcpyStoresTo(*this, slot);
1060 }
1061 
1062 Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
1063  RewriterBase &rewriter) {
1064  return memcpyGetStored(*this, slot, rewriter);
1065 }
1066 
1067 bool LLVM::MemcpyOp::canUsesBeRemoved(
1068  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1069  SmallVectorImpl<OpOperand *> &newBlockingUses,
1070  const DataLayout &dataLayout) {
1071  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1072  dataLayout);
1073 }
1074 
1075 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
1076  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1077  RewriterBase &rewriter, Value reachingDefinition) {
1078  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
1079  reachingDefinition);
1080 }
1081 
1082 LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
1083  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1084  const DataLayout &dataLayout) {
1085  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1086 }
1087 
1088 bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
1089  SmallPtrSetImpl<Attribute> &usedIndices,
1090  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1091  const DataLayout &dataLayout) {
1092  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1093  dataLayout);
1094 }
1095 
1096 DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
1098  RewriterBase &rewriter,
1099  const DataLayout &dataLayout) {
1100  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
1101 }
1102 
1103 bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
1104  return memcpyLoadsFrom(*this, slot);
1105 }
1106 
1107 bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
1108  return memcpyStoresTo(*this, slot);
1109 }
1110 
1111 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
1112  RewriterBase &rewriter) {
1113  return memcpyGetStored(*this, slot, rewriter);
1114 }
1115 
1116 bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
1117  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1118  SmallVectorImpl<OpOperand *> &newBlockingUses,
1119  const DataLayout &dataLayout) {
1120  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1121  dataLayout);
1122 }
1123 
1124 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
1125  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1126  RewriterBase &rewriter, Value reachingDefinition) {
1127  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
1128  reachingDefinition);
1129 }
1130 
1131 LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
1132  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1133  const DataLayout &dataLayout) {
1134  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1135 }
1136 
1137 bool LLVM::MemcpyInlineOp::canRewire(
1138  const DestructurableMemorySlot &slot,
1139  SmallPtrSetImpl<Attribute> &usedIndices,
1140  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1141  const DataLayout &dataLayout) {
1142  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1143  dataLayout);
1144 }
1145 
1147 LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
1149  RewriterBase &rewriter,
1150  const DataLayout &dataLayout) {
1151  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
1152 }
1153 
1154 bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
1155  return memcpyLoadsFrom(*this, slot);
1156 }
1157 
1158 bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
1159  return memcpyStoresTo(*this, slot);
1160 }
1161 
1162 Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
1163  RewriterBase &rewriter) {
1164  return memcpyGetStored(*this, slot, rewriter);
1165 }
1166 
1167 bool LLVM::MemmoveOp::canUsesBeRemoved(
1168  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1169  SmallVectorImpl<OpOperand *> &newBlockingUses,
1170  const DataLayout &dataLayout) {
1171  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1172  dataLayout);
1173 }
1174 
1175 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
1176  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1177  RewriterBase &rewriter, Value reachingDefinition) {
1178  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
1179  reachingDefinition);
1180 }
1181 
1182 LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
1183  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1184  const DataLayout &dataLayout) {
1185  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
1186 }
1187 
1188 bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
1189  SmallPtrSetImpl<Attribute> &usedIndices,
1190  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1191  const DataLayout &dataLayout) {
1192  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1193  dataLayout);
1194 }
1195 
1196 DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
1198  RewriterBase &rewriter,
1199  const DataLayout &dataLayout) {
1200  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
1201 }
1202 
1203 //===----------------------------------------------------------------------===//
1204 // Interfaces for destructurable types
1205 //===----------------------------------------------------------------------===//
1206 
1207 std::optional<DenseMap<Attribute, Type>>
1209  Type i32 = IntegerType::get(getContext(), 32);
1210  DenseMap<Attribute, Type> destructured;
1211  for (const auto &[index, elemType] : llvm::enumerate(getBody()))
1212  destructured.insert({IntegerAttr::get(i32, index), elemType});
1213  return destructured;
1214 }
1215 
1217  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1218  if (!indexAttr || !indexAttr.getType().isInteger(32))
1219  return {};
1220  int32_t indexInt = indexAttr.getInt();
1221  ArrayRef<Type> body = getBody();
1222  if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
1223  return {};
1224  return body[indexInt];
1225 }
1226 
1227 std::optional<DenseMap<Attribute, Type>>
1228 LLVM::LLVMArrayType::getSubelementIndexMap() const {
1229  constexpr size_t maxArraySizeForDestructuring = 16;
1230  if (getNumElements() > maxArraySizeForDestructuring)
1231  return {};
1232  int32_t numElements = getNumElements();
1233 
1234  Type i32 = IntegerType::get(getContext(), 32);
1235  DenseMap<Attribute, Type> destructured;
1236  for (int32_t index = 0; index < numElements; ++index)
1237  destructured.insert({IntegerAttr::get(i32, index), getElementType()});
1238  return destructured;
1239 }
1240 
1242  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1243  if (!indexAttr || !indexAttr.getType().isInteger(32))
1244  return {};
1245  int32_t indexInt = indexAttr.getInt();
1246  if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
1247  return {};
1248  return getElementType();
1249 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context, unsigned size)
Constructs a byte array type of the given size.
static LogicalResult memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed)
static DeletionKind memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, RewriterBase &rewriter, Value reachingDefinition)
static std::optional< uint64_t > gepToByteOffset(const DataLayout &dataLayout, LLVM::GEPOp gep)
Returns the amount of bytes the provided GEP elements will offset the pointer by.
static bool areAllIndicesI32(const DestructurableMemorySlot &slot)
Checks whether all indices are i32.
static std::optional< SubslotAccessInfo > getSubslotAccessInfo(const DestructurableMemorySlot &slot, const DataLayout &dataLayout, LLVM::GEPOp gep)
Computes subslot access information for an access into slot with the given offset.
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot)
static Value createConversionSequence(RewriterBase &rewriter, Location loc, Value inputValue, Type targetType)
Constructs operations that convert inputValue into a new value of type targetType.
static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index)
Returns the subslot's type at the requested index.
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs)
Checks that two types are the same or can be cast into one another.
static bool forwardToUsers(Operation *op, SmallVectorImpl< OpOperand * > &newBlockingUses)
Conditions the deletion of the operation to the removal of all its uses.
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot)
static bool memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, SmallVectorImpl< OpOperand * > &newBlockingUses, const DataLayout &dataLayout)
static DeletionKind memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot, DenseMap< Attribute, MemorySlot > &subslots, RewriterBase &rewriter, const DataLayout &dataLayout)
Rewires a memcpy-like operation.
static bool hasAllZeroIndices(LLVM::GEPOp gepOp)
static bool isValidAccessType(const MemorySlot &slot, Type accessType, const DataLayout &dataLayout)
Checks if slot can be accessed through the provided access type.
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot, RewriterBase &rewriter)
static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot, const DataLayout &dataLayout)
Returns whether one can be sure the memory intrinsic does not write outside of the bounds of the give...
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot, SmallPtrSetImpl< Attribute > &usedIndices, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed, const DataLayout &dataLayout)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1541
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
MLIRContext * getContext() const
Definition: Builders.h:55
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:109
Type getTypeAtIndex(Attribute index)
Returns which type is stored at a given integer index within the struct.
bool isPacked() const
Checks if a struct is packed.
Definition: LLVMTypes.cpp:482
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Definition: LLVMTypes.cpp:490
std::optional< DenseMap< Attribute, Type > > getSubelementIndexMap()
Destructs the struct into its indexed field types.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:340
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:214
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:128
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr int kGEPConstantBitWidth
Bit-width of a 'GEPConstantIndex' within GEPArg.
Definition: LLVMDialect.h:65
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the corresponding subelement type.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.