MLIR  18.0.0git
LLVMMemorySlot.cpp
Go to the documentation of this file.
1 //===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MemorySlot-related interfaces for LLVM dialect
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/ValueRange.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Interfaces for AllocaOp
29 //===----------------------------------------------------------------------===//
30 
31 llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
32  if (!getOperation()->getBlock()->isEntryBlock())
33  return {};
34 
35  return {MemorySlot{getResult(), getResultPtrElementType()}};
36 }
37 
38 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
39  RewriterBase &rewriter) {
40  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
41 }
42 
43 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
44  BlockArgument argument,
45  RewriterBase &rewriter) {
46  for (Operation *user : getOperation()->getUsers())
47  if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
48  rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
49  declareOp.getVarInfo());
50 }
51 
52 void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
53  Value defaultValue,
54  RewriterBase &rewriter) {
55  if (defaultValue && defaultValue.use_empty())
56  rewriter.eraseOp(defaultValue.getDefiningOp());
57  rewriter.eraseOp(*this);
58 }
59 
60 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
61  if (!mlir::matchPattern(getArraySize(), m_One()))
62  return {};
63 
64  Type elemType = getResultPtrElementType();
65  auto destructurable = dyn_cast<DestructurableTypeInterface>(elemType);
66  if (!destructurable)
67  return {};
68 
69  std::optional<DenseMap<Attribute, Type>> destructuredType =
70  destructurable.getSubelementIndexMap();
71  if (!destructuredType)
72  return {};
73 
74  DenseMap<Attribute, Type> allocaTypeMap;
75  for (Attribute index : llvm::make_first_range(destructuredType.value()))
76  allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
77 
78  return {DestructurableMemorySlot{{getResult(), elemType}, {allocaTypeMap}}};
79 }
80 
82 LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
83  const SmallPtrSetImpl<Attribute> &usedIndices,
84  RewriterBase &rewriter) {
85  assert(slot.ptr == getResult());
86  Type elemType =
87  getElemType() ? *getElemType() : getResult().getType().getElementType();
88 
89  rewriter.setInsertionPointAfter(*this);
90 
91  auto destructurableType = cast<DestructurableTypeInterface>(elemType);
93  for (Attribute index : usedIndices) {
94  Type elemType = destructurableType.getTypeAtIndex(index);
95  assert(elemType && "used index must exist");
96  auto subAlloca = rewriter.create<LLVM::AllocaOp>(
97  getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
98  getArraySize());
99  slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
100  }
101 
102  return slotMap;
103 }
104 
105 void LLVM::AllocaOp::handleDestructuringComplete(
106  const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
107  assert(slot.ptr == getResult());
108  rewriter.eraseOp(*this);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Interfaces for LoadOp/StoreOp
113 //===----------------------------------------------------------------------===//
114 
115 bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
116  return getAddr() == slot.ptr;
117 }
118 
119 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
120 
121 Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
122  llvm_unreachable("getStored should not be called on LoadOp");
123 }
124 
125 bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
126 
127 bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
128  return getAddr() == slot.ptr;
129 }
130 
131 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
132  return getValue();
133 }
134 
135 bool LLVM::LoadOp::canUsesBeRemoved(
136  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
137  SmallVectorImpl<OpOperand *> &newBlockingUses) {
138  if (blockingUses.size() != 1)
139  return false;
140  Value blockingUse = (*blockingUses.begin())->get();
141  // If the blocking use is the slot ptr itself, there will be enough
142  // context to reconstruct the result of the load at removal time, so it can
143  // be removed (provided it loads the exact stored value and is not
144  // volatile).
145  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
146  getResult().getType() == slot.elemType && !getVolatile_();
147 }
148 
149 DeletionKind LLVM::LoadOp::removeBlockingUses(
150  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
151  RewriterBase &rewriter, Value reachingDefinition) {
152  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
153  // pointer.
154  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
155  return DeletionKind::Delete;
156 }
157 
158 bool LLVM::StoreOp::canUsesBeRemoved(
159  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
160  SmallVectorImpl<OpOperand *> &newBlockingUses) {
161  if (blockingUses.size() != 1)
162  return false;
163  Value blockingUse = (*blockingUses.begin())->get();
164  // If the blocking use is the slot ptr itself, dropping the store is
165  // fine, provided we are currently promoting its target value. Don't allow a
166  // store OF the slot pointer, only INTO the slot pointer.
167  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
168  getValue() != slot.ptr && getValue().getType() == slot.elemType &&
169  !getVolatile_();
170 }
171 
172 DeletionKind LLVM::StoreOp::removeBlockingUses(
173  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
174  RewriterBase &rewriter, Value reachingDefinition) {
175  // `canUsesBeRemoved` checked this blocking use must be the stored slot
176  // pointer.
177  for (Operation *user : slot.ptr.getUsers())
178  if (auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
179  rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
180  declareOp.getVarInfo());
181  return DeletionKind::Delete;
182 }
183 
184 LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
185  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
186  return success(getAddr() != slot.ptr || getType() == slot.elemType);
187 }
188 
189 LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
190  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
191  return success(getAddr() != slot.ptr ||
192  getValue().getType() == slot.elemType);
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Interfaces for discardable OPs
197 //===----------------------------------------------------------------------===//
198 
199 /// Conditions the deletion of the operation to the removal of all its uses.
200 static bool forwardToUsers(Operation *op,
201  SmallVectorImpl<OpOperand *> &newBlockingUses) {
202  for (Value result : op->getResults())
203  for (OpOperand &use : result.getUses())
204  newBlockingUses.push_back(&use);
205  return true;
206 }
207 
208 bool LLVM::BitcastOp::canUsesBeRemoved(
209  const SmallPtrSetImpl<OpOperand *> &blockingUses,
210  SmallVectorImpl<OpOperand *> &newBlockingUses) {
211  return forwardToUsers(*this, newBlockingUses);
212 }
213 
214 DeletionKind LLVM::BitcastOp::removeBlockingUses(
215  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
216  return DeletionKind::Delete;
217 }
218 
219 bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
220  const SmallPtrSetImpl<OpOperand *> &blockingUses,
221  SmallVectorImpl<OpOperand *> &newBlockingUses) {
222  return forwardToUsers(*this, newBlockingUses);
223 }
224 
225 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
226  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
227  return DeletionKind::Delete;
228 }
229 
230 bool LLVM::LifetimeStartOp::canUsesBeRemoved(
231  const SmallPtrSetImpl<OpOperand *> &blockingUses,
232  SmallVectorImpl<OpOperand *> &newBlockingUses) {
233  return true;
234 }
235 
236 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
237  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
238  return DeletionKind::Delete;
239 }
240 
241 bool LLVM::LifetimeEndOp::canUsesBeRemoved(
242  const SmallPtrSetImpl<OpOperand *> &blockingUses,
243  SmallVectorImpl<OpOperand *> &newBlockingUses) {
244  return true;
245 }
246 
247 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
248  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
249  return DeletionKind::Delete;
250 }
251 
252 bool LLVM::DbgDeclareOp::canUsesBeRemoved(
253  const SmallPtrSetImpl<OpOperand *> &blockingUses,
254  SmallVectorImpl<OpOperand *> &newBlockingUses) {
255  return true;
256 }
257 
258 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
259  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
260  return DeletionKind::Delete;
261 }
262 
263 bool LLVM::DbgValueOp::canUsesBeRemoved(
264  const SmallPtrSetImpl<OpOperand *> &blockingUses,
265  SmallVectorImpl<OpOperand *> &newBlockingUses) {
266  // There is only one operand that we can remove the use of.
267  if (blockingUses.size() != 1)
268  return false;
269 
270  return (*blockingUses.begin())->get() == getValue();
271 }
272 
273 DeletionKind LLVM::DbgValueOp::removeBlockingUses(
274  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
275  // Rewriter by default is after '*this', but we need it before '*this'.
276  rewriter.setInsertionPoint(*this);
277 
278  // Rather than dropping the debug value, replace it with undef to preserve the
279  // debug local variable info. This allows the debugger to inform the user that
280  // the variable has been optimized out.
281  auto undef =
282  rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
283  rewriter.updateRootInPlace(*this, [&] { getValueMutable().assign(undef); });
284  return DeletionKind::Keep;
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // Interfaces for GEPOp
289 //===----------------------------------------------------------------------===//
290 
291 static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
292  return llvm::all_of(gepOp.getIndices(), [](auto index) {
293  auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
294  return indexAttr && indexAttr.getValue() == 0;
295  });
296 }
297 
298 bool LLVM::GEPOp::canUsesBeRemoved(
299  const SmallPtrSetImpl<OpOperand *> &blockingUses,
300  SmallVectorImpl<OpOperand *> &newBlockingUses) {
301  // GEP can be removed as long as it is a no-op and its users can be removed.
302  if (!hasAllZeroIndices(*this))
303  return false;
304  return forwardToUsers(*this, newBlockingUses);
305 }
306 
307 DeletionKind LLVM::GEPOp::removeBlockingUses(
308  const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
309  return DeletionKind::Delete;
310 }
311 
312 static bool isFirstIndexZero(LLVM::GEPOp gep) {
313  IntegerAttr index =
314  llvm::dyn_cast_if_present<IntegerAttr>(gep.getIndices()[0]);
315  return index && index.getInt() == 0;
316 }
317 
318 LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
319  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
320  if (getBase() != slot.ptr)
321  return success();
322  if (slot.elemType != getElemType())
323  return failure();
324  if (!isFirstIndexZero(*this))
325  return failure();
326  Type reachedType = getResultPtrElementType();
327  if (!reachedType)
328  return failure();
329  mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
330  return success();
331 }
332 
333 bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
334  SmallPtrSetImpl<Attribute> &usedIndices,
335  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
336  auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
337  if (!basePtrType)
338  return false;
339 
340  // Typed pointers are not supported. This should be removed once typed
341  // pointers are removed from the LLVM dialect.
342  if (!basePtrType.isOpaque())
343  return false;
344 
345  if (getBase() != slot.ptr || slot.elemType != getElemType())
346  return false;
347  if (!isFirstIndexZero(*this))
348  return false;
349  Type reachedType = getResultPtrElementType();
350  if (!reachedType || getIndices().size() < 2)
351  return false;
352  auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
353  if (!firstLevelIndex)
354  return false;
355  assert(slot.elementPtrs.contains(firstLevelIndex));
356  if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
357  return false;
358  mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
359  usedIndices.insert(firstLevelIndex);
360  return true;
361 }
362 
363 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
365  RewriterBase &rewriter) {
366  IntegerAttr firstLevelIndex =
367  llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
368  const MemorySlot &newSlot = subslots.at(firstLevelIndex);
369 
370  ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
371 
372  // If the GEP would become trivial after this transformation, eliminate it.
373  // A GEP should only be eliminated if it has no indices (except the first
374  // pointer index), as simplifying GEPs with all-zero indices would eliminate
375  // structure information useful for further destruction.
376  if (remainingIndices.empty()) {
377  rewriter.replaceAllUsesWith(getResult(), newSlot.ptr);
378  return DeletionKind::Delete;
379  }
380 
381  rewriter.updateRootInPlace(*this, [&]() {
382  // Rewire the indices by popping off the second index.
383  // Start with a single zero, then add the indices beyond the second.
384  SmallVector<int32_t> newIndices(1);
385  newIndices.append(remainingIndices.begin(), remainingIndices.end());
386  setRawConstantIndices(newIndices);
387 
388  // Rewire the pointed type.
389  setElemType(newSlot.elemType);
390 
391  // Rewire the pointer.
392  getBaseMutable().assign(newSlot.ptr);
393  });
394 
395  return DeletionKind::Keep;
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // Utilities for memory intrinsics
400 //===----------------------------------------------------------------------===//
401 
402 namespace {
403 
404 /// Returns the length of the given memory intrinsic in bytes if it can be known
405 /// at compile-time on a best-effort basis, nothing otherwise.
406 template <class MemIntr>
407 std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
408  APInt memIntrLen;
409  if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
410  return {};
411  if (memIntrLen.getBitWidth() > 64)
412  return {};
413  return memIntrLen.getZExtValue();
414 }
415 
416 /// Returns the length of the given memory intrinsic in bytes if it can be known
417 /// at compile-time on a best-effort basis, nothing otherwise.
418 /// Because MemcpyInlineOp has its length encoded as an attribute, this requires
419 /// specialized handling.
420 template <>
421 std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
422  APInt memIntrLen = op.getLen();
423  if (memIntrLen.getBitWidth() > 64)
424  return {};
425  return memIntrLen.getZExtValue();
426 }
427 
428 } // namespace
429 
430 /// Returns whether one can be sure the memory intrinsic does not write outside
431 /// of the bounds of the given slot, on a best-effort basis.
432 template <class MemIntr>
433 static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
434  DataLayout &dataLayout) {
435  if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
436  op.getDst() != slot.ptr)
437  return false;
438 
439  std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op);
440  return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(slot.elemType);
441 }
442 
443 /// Checks whether all indices are i32. This is used to check GEPs can index
444 /// into them.
445 static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
446  Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
447  return llvm::all_of(llvm::make_first_range(slot.elementPtrs),
448  [&](Attribute index) {
449  auto intIndex = dyn_cast<IntegerAttr>(index);
450  return intIndex && intIndex.getType() == i32;
451  });
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // Interfaces for memset
456 //===----------------------------------------------------------------------===//
457 
458 bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
459 
460 bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
461  return getDst() == slot.ptr;
462 }
463 
464 Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
465  RewriterBase &rewriter) {
466  // TODO: Support non-integer types.
467  return TypeSwitch<Type, Value>(slot.elemType)
468  .Case([&](IntegerType intType) -> Value {
469  if (intType.getWidth() == 8)
470  return getVal();
471 
472  assert(intType.getWidth() % 8 == 0);
473 
474  // Build the memset integer by repeatedly shifting the value and
475  // or-ing it with the previous value.
476  uint64_t coveredBits = 8;
477  Value currentValue =
478  rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
479  while (coveredBits < intType.getWidth()) {
480  Value shiftBy =
481  rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
482  Value shifted =
483  rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
484  currentValue =
485  rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
486  coveredBits *= 2;
487  }
488 
489  return currentValue;
490  })
491  .Default([](Type) -> Value {
492  llvm_unreachable(
493  "getStored should not be called on memset to unsupported type");
494  });
495 }
496 
497 bool LLVM::MemsetOp::canUsesBeRemoved(
498  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
499  SmallVectorImpl<OpOperand *> &newBlockingUses) {
500  // TODO: Support non-integer types.
501  bool canConvertType =
503  .Case([](IntegerType intType) {
504  return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
505  })
506  .Default([](Type) { return false; });
507  if (!canConvertType)
508  return false;
509 
510  if (getIsVolatile())
511  return false;
512 
513  DataLayout layout = DataLayout::closest(*this);
514  return getStaticMemIntrLen(*this) == layout.getTypeSize(slot.elemType);
515 }
516 
517 DeletionKind LLVM::MemsetOp::removeBlockingUses(
518  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
519  RewriterBase &rewriter, Value reachingDefinition) {
520  return DeletionKind::Delete;
521 }
522 
523 LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
524  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
525  DataLayout dataLayout = DataLayout::closest(*this);
526  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
527 }
528 
529 bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
530  SmallPtrSetImpl<Attribute> &usedIndices,
531  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
532  if (&slot.elemType.getDialect() != getOperation()->getDialect())
533  return false;
534 
535  if (getIsVolatile())
536  return false;
537 
538  if (!slot.elemType.cast<DestructurableTypeInterface>()
539  .getSubelementIndexMap())
540  return false;
541 
542  if (!areAllIndicesI32(slot))
543  return false;
544 
545  DataLayout dataLayout = DataLayout::closest(*this);
546  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
547 }
548 
549 DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
551  RewriterBase &rewriter) {
552  std::optional<DenseMap<Attribute, Type>> types =
553  slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
554 
555  IntegerAttr memsetLenAttr;
556  bool successfulMatch =
557  matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
558  (void)successfulMatch;
559  assert(successfulMatch);
560 
561  bool packed = false;
562  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
563  packed = structType.isPacked();
564 
565  Type i32 = IntegerType::get(getContext(), 32);
566  DataLayout dataLayout = DataLayout::closest(*this);
567  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
568  uint64_t covered = 0;
569  for (size_t i = 0; i < types->size(); i++) {
570  // Create indices on the fly to get elements in the right order.
571  Attribute index = IntegerAttr::get(i32, i);
572  Type elemType = types->at(index);
573  uint64_t typeSize = dataLayout.getTypeSize(elemType);
574 
575  if (!packed)
576  covered =
577  llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
578 
579  if (covered >= memsetLen)
580  break;
581 
582  // If this subslot is used, apply a new memset to it.
583  // Otherwise, only compute its offset within the original memset.
584  if (subslots.contains(index)) {
585  uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
586 
587  Value newMemsetSizeValue =
588  rewriter
589  .create<LLVM::ConstantOp>(
590  getLen().getLoc(),
591  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
592  .getResult();
593 
594  rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
595  getVal(), newMemsetSizeValue,
596  getIsVolatile());
597  }
598 
599  covered += typeSize;
600  }
601 
602  return DeletionKind::Delete;
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // Interfaces for memcpy/memmove
607 //===----------------------------------------------------------------------===//
608 
609 template <class MemcpyLike>
610 static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
611  return op.getSrc() == slot.ptr;
612 }
613 
614 template <class MemcpyLike>
615 static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
616  return op.getDst() == slot.ptr;
617 }
618 
619 template <class MemcpyLike>
620 static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
621  RewriterBase &rewriter) {
622  return rewriter.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
623 }
624 
625 template <class MemcpyLike>
626 static bool
627 memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
628  const SmallPtrSetImpl<OpOperand *> &blockingUses,
629  SmallVectorImpl<OpOperand *> &newBlockingUses) {
630  // If source and destination are the same, memcpy behavior is undefined and
631  // memmove is a no-op. Because there is no memory change happening here,
632  // simplifying such operations is left to canonicalization.
633  if (op.getDst() == op.getSrc())
634  return false;
635 
636  if (op.getIsVolatile())
637  return false;
638 
639  DataLayout layout = DataLayout::closest(op);
640  return getStaticMemIntrLen(op) == layout.getTypeSize(slot.elemType);
641 }
642 
643 template <class MemcpyLike>
644 static DeletionKind
645 memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
646  const SmallPtrSetImpl<OpOperand *> &blockingUses,
647  RewriterBase &rewriter, Value reachingDefinition) {
648  if (op.loadsFrom(slot))
649  rewriter.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition,
650  op.getDst());
651  return DeletionKind::Delete;
652 }
653 
654 template <class MemcpyLike>
655 static LogicalResult
656 memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
657  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
658  DataLayout dataLayout = DataLayout::closest(op);
659  // While rewiring memcpy-like intrinsics only supports full copies, partial
660  // copies are still safe accesses so it is enough to only check for writes
661  // within bounds.
662  return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
663 }
664 
665 template <class MemcpyLike>
666 static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
667  SmallPtrSetImpl<Attribute> &usedIndices,
668  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
669  if (op.getIsVolatile())
670  return false;
671 
672  if (!slot.elemType.cast<DestructurableTypeInterface>()
673  .getSubelementIndexMap())
674  return false;
675 
676  if (!areAllIndicesI32(slot))
677  return false;
678 
679  // Only full copies are supported.
680  DataLayout dataLayout = DataLayout::closest(op);
681  if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
682  return false;
683 
684  if (op.getSrc() == slot.ptr)
685  for (Attribute index : llvm::make_first_range(slot.elementPtrs))
686  usedIndices.insert(index);
687 
688  return true;
689 }
690 
691 namespace {
692 
693 template <class MemcpyLike>
694 void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
695  MemcpyLike toReplace, Value dst, Value src,
696  Type toCpy, bool isVolatile) {
697  Value memcpySize = rewriter.create<LLVM::ConstantOp>(
698  toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
699  layout.getTypeSize(toCpy)));
700  rewriter.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
701  isVolatile);
702 }
703 
704 template <>
705 void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
706  LLVM::MemcpyInlineOp toReplace, Value dst,
707  Value src, Type toCpy, bool isVolatile) {
708  Type lenType = IntegerType::get(toReplace->getContext(),
709  toReplace.getLen().getBitWidth());
710  rewriter.create<LLVM::MemcpyInlineOp>(
711  toReplace.getLoc(), dst, src,
712  IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
713 }
714 
715 } // namespace
716 
717 /// Rewires a memcpy-like operation. Only copies to or from the full slot are
718 /// supported.
719 template <class MemcpyLike>
720 static DeletionKind memcpyRewire(MemcpyLike op,
721  const DestructurableMemorySlot &slot,
723  RewriterBase &rewriter) {
724  if (subslots.empty())
725  return DeletionKind::Delete;
726 
727  DataLayout layout = DataLayout::closest(op);
728 
729  assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
730  bool isDst = slot.ptr == op.getDst();
731 
732 #ifndef NDEBUG
733  size_t slotsTreated = 0;
734 #endif
735 
736  // It was previously checked that index types are consistent, so this type can
737  // be fetched now.
738  Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
739  for (size_t i = 0, e = slot.elementPtrs.size(); i != e; i++) {
740  Attribute index = IntegerAttr::get(indexType, i);
741  if (!subslots.contains(index))
742  continue;
743  const MemorySlot &subslot = subslots.at(index);
744 
745 #ifndef NDEBUG
746  slotsTreated++;
747 #endif
748 
749  // First get a pointer to the equivalent of this subslot from the source
750  // pointer.
751  SmallVector<LLVM::GEPArg> gepIndices{
752  0, static_cast<int32_t>(
753  cast<IntegerAttr>(index).getValue().getZExtValue())};
754  Value subslotPtrInOther = rewriter.create<LLVM::GEPOp>(
756  isDst ? op.getSrc() : op.getDst(), gepIndices);
757 
758  // Then create a new memcpy out of this source pointer.
759  createMemcpyLikeToReplace(rewriter, layout, op,
760  isDst ? subslot.ptr : subslotPtrInOther,
761  isDst ? subslotPtrInOther : subslot.ptr,
762  subslot.elemType, op.getIsVolatile());
763  }
764 
765  assert(subslots.size() == slotsTreated);
766 
767  return DeletionKind::Delete;
768 }
769 
770 bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
771  return memcpyLoadsFrom(*this, slot);
772 }
773 
774 bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
775  return memcpyStoresTo(*this, slot);
776 }
777 
778 Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
779  RewriterBase &rewriter) {
780  return memcpyGetStored(*this, slot, rewriter);
781 }
782 
783 bool LLVM::MemcpyOp::canUsesBeRemoved(
784  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
785  SmallVectorImpl<OpOperand *> &newBlockingUses) {
786  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
787 }
788 
789 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
790  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
791  RewriterBase &rewriter, Value reachingDefinition) {
792  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
793  reachingDefinition);
794 }
795 
796 LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
797  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
798  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
799 }
800 
801 bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
802  SmallPtrSetImpl<Attribute> &usedIndices,
803  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
804  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
805 }
806 
807 DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
809  RewriterBase &rewriter) {
810  return memcpyRewire(*this, slot, subslots, rewriter);
811 }
812 
813 bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
814  return memcpyLoadsFrom(*this, slot);
815 }
816 
817 bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
818  return memcpyStoresTo(*this, slot);
819 }
820 
821 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
822  RewriterBase &rewriter) {
823  return memcpyGetStored(*this, slot, rewriter);
824 }
825 
826 bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
827  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
828  SmallVectorImpl<OpOperand *> &newBlockingUses) {
829  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
830 }
831 
832 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
833  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
834  RewriterBase &rewriter, Value reachingDefinition) {
835  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
836  reachingDefinition);
837 }
838 
839 LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
840  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
841  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
842 }
843 
844 bool LLVM::MemcpyInlineOp::canRewire(
845  const DestructurableMemorySlot &slot,
846  SmallPtrSetImpl<Attribute> &usedIndices,
847  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
848  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
849 }
850 
852 LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
854  RewriterBase &rewriter) {
855  return memcpyRewire(*this, slot, subslots, rewriter);
856 }
857 
858 bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
859  return memcpyLoadsFrom(*this, slot);
860 }
861 
862 bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
863  return memcpyStoresTo(*this, slot);
864 }
865 
866 Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
867  RewriterBase &rewriter) {
868  return memcpyGetStored(*this, slot, rewriter);
869 }
870 
871 bool LLVM::MemmoveOp::canUsesBeRemoved(
872  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
873  SmallVectorImpl<OpOperand *> &newBlockingUses) {
874  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
875 }
876 
877 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
878  const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
879  RewriterBase &rewriter, Value reachingDefinition) {
880  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
881  reachingDefinition);
882 }
883 
884 LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
885  const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
886  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
887 }
888 
889 bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
890  SmallPtrSetImpl<Attribute> &usedIndices,
891  SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
892  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
893 }
894 
895 DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
897  RewriterBase &rewriter) {
898  return memcpyRewire(*this, slot, subslots, rewriter);
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // Interfaces for destructurable types
903 //===----------------------------------------------------------------------===//
904 
905 std::optional<DenseMap<Attribute, Type>>
907  Type i32 = IntegerType::get(getContext(), 32);
908  DenseMap<Attribute, Type> destructured;
909  for (const auto &[index, elemType] : llvm::enumerate(getBody()))
910  destructured.insert({IntegerAttr::get(i32, index), elemType});
911  return destructured;
912 }
913 
915  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
916  if (!indexAttr || !indexAttr.getType().isInteger(32))
917  return {};
918  int32_t indexInt = indexAttr.getInt();
919  ArrayRef<Type> body = getBody();
920  if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
921  return {};
922  return body[indexInt];
923 }
924 
925 std::optional<DenseMap<Attribute, Type>>
926 LLVM::LLVMArrayType::getSubelementIndexMap() const {
927  constexpr size_t maxArraySizeForDestructuring = 16;
928  if (getNumElements() > maxArraySizeForDestructuring)
929  return {};
930  int32_t numElements = getNumElements();
931 
932  Type i32 = IntegerType::get(getContext(), 32);
933  DenseMap<Attribute, Type> destructured;
934  for (int32_t index = 0; index < numElements; ++index)
935  destructured.insert({IntegerAttr::get(i32, index), getElementType()});
936  return destructured;
937 }
938 
939 Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
940  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
941  if (!indexAttr || !indexAttr.getType().isInteger(32))
942  return {};
943  int32_t indexInt = indexAttr.getInt();
944  if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
945  return {};
946  return getElementType();
947 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed)
static DeletionKind memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, RewriterBase &rewriter, Value reachingDefinition)
static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot, DataLayout &dataLayout)
Returns whether one can be sure the memory intrinsic does not write outside of the bounds of the give...
static bool areAllIndicesI32(const DestructurableMemorySlot &slot)
Checks whether all indices are i32.
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot)
static bool isFirstIndexZero(LLVM::GEPOp gep)
static bool forwardToUsers(Operation *op, SmallVectorImpl< OpOperand * > &newBlockingUses)
Conditions the deletion of the operation to the removal of all its uses.
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot)
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot, SmallPtrSetImpl< Attribute > &usedIndices, SmallVectorImpl< MemorySlot > &mustBeSafelyUsed)
static bool hasAllZeroIndices(LLVM::GEPOp gepOp)
static bool memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot, const SmallPtrSetImpl< OpOperand * > &blockingUses, SmallVectorImpl< OpOperand * > &newBlockingUses)
static DeletionKind memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot, DenseMap< Attribute, MemorySlot > &subslots, RewriterBase &rewriter)
Rewires a memcpy-like operation.
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot, RewriterBase &rewriter)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1333
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:310
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
unsigned getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
unsigned getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Type getTypeAtIndex(Attribute index)
Returns which type is stored at a given integer index within the struct.
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Definition: LLVMTypes.cpp:539
std::optional< DenseMap< Attribute, Type > > getSubelementIndexMap()
Destructs the struct into its indexed field types.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
This class represents an operand of an operation.
Definition: Value.h:261
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:339
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:212
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:125
Type getType() const
Return the type of this value.
Definition: Value.h:122
user_range getUsers() const
Definition: Value.h:222
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
DeletionKind
Returned by operation promotion logic requesting the deletion of an operation.
@ Keep
Keep the operation after promotion.
@ Delete
Delete the operation after promotion.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Memory slot attached with information about its destructuring procedure.
DenseMap< Attribute, Type > elementPtrs
Maps an index within the memory slot to the type of the pointer that will be generated to access the ...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.