MLIR  19.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/SHA1.h"
#include <numeric>
#include <optional>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // NamedAttrList
27 //===----------------------------------------------------------------------===//
28 
30  assign(attributes.begin(), attributes.end());
31 }
32 
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34  : NamedAttrList(attributes ? attributes.getValue()
35  : ArrayRef<NamedAttribute>()) {
36  dictionarySorted.setPointerAndInt(attributes, true);
37 }
38 
40  assign(inStart, inEnd);
41 }
42 
44 
45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46  std::optional<NamedAttribute> duplicate =
47  DictionaryAttr::findDuplicate(attrs, isSorted());
48  // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49  // state.
50  if (!isSorted())
51  dictionarySorted.setPointerAndInt(nullptr, true);
52  return duplicate;
53 }
54 
55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56  if (!isSorted()) {
57  DictionaryAttr::sortInPlace(attrs);
58  dictionarySorted.setPointerAndInt(nullptr, true);
59  }
60  if (!dictionarySorted.getPointer())
61  dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62  return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
63 }
64 
65 /// Add an attribute with the specified name.
66 void NamedAttrList::append(StringRef name, Attribute attr) {
67  append(StringAttr::get(attr.getContext(), name), attr);
68 }
69 
70 /// Replaces the attributes with new list of attributes.
72  DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
73  dictionarySorted.setPointerAndInt(nullptr, true);
74 }
75 
77  if (isSorted())
78  dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
79  dictionarySorted.setPointer(nullptr);
80  attrs.push_back(newAttribute);
81 }
82 
83 /// Return the specified attribute if present, null otherwise.
84 Attribute NamedAttrList::get(StringRef name) const {
85  auto it = findAttr(*this, name);
86  return it.second ? it.first->getValue() : Attribute();
87 }
88 Attribute NamedAttrList::get(StringAttr name) const {
89  auto it = findAttr(*this, name);
90  return it.second ? it.first->getValue() : Attribute();
91 }
92 
93 /// Return the specified named attribute if present, std::nullopt otherwise.
94 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
95  auto it = findAttr(*this, name);
96  return it.second ? *it.first : std::optional<NamedAttribute>();
97 }
98 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
99  auto it = findAttr(*this, name);
100  return it.second ? *it.first : std::optional<NamedAttribute>();
101 }
102 
103 /// If the an attribute exists with the specified name, change it to the new
104 /// value. Otherwise, add a new attribute with the specified name/value.
105 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
106  assert(value && "attributes may never be null");
107 
108  // Look for an existing attribute with the given name, and set its value
109  // in-place. Return the previous value of the attribute, if there was one.
110  auto it = findAttr(*this, name);
111  if (it.second) {
112  // Update the existing attribute by swapping out the old value for the new
113  // value. Return the old value.
114  Attribute oldValue = it.first->getValue();
115  if (it.first->getValue() != value) {
116  it.first->setValue(value);
117 
118  // If the attributes have changed, the dictionary is invalidated.
119  dictionarySorted.setPointer(nullptr);
120  }
121  return oldValue;
122  }
123  // Perform a string lookup to insert the new attribute into its sorted
124  // position.
125  if (isSorted())
126  it = findAttr(*this, name.strref());
127  attrs.insert(it.first, {name, value});
128  // Invalidate the dictionary. Return null as there was no previous value.
129  dictionarySorted.setPointer(nullptr);
130  return Attribute();
131 }
132 
133 Attribute NamedAttrList::set(StringRef name, Attribute value) {
134  assert(value && "attributes may never be null");
135  return set(mlir::StringAttr::get(value.getContext(), name), value);
136 }
137 
138 Attribute
139 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
140  // Erasing does not affect the sorted property.
141  Attribute attr = it->getValue();
142  attrs.erase(it);
143  dictionarySorted.setPointer(nullptr);
144  return attr;
145 }
146 
147 Attribute NamedAttrList::erase(StringAttr name) {
148  auto it = findAttr(*this, name);
149  return it.second ? eraseImpl(it.first) : Attribute();
150 }
151 
153  auto it = findAttr(*this, name);
154  return it.second ? eraseImpl(it.first) : Attribute();
155 }
156 
159  assign(rhs.begin(), rhs.end());
160  return *this;
161 }
162 
163 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
164 
165 //===----------------------------------------------------------------------===//
166 // OperationState
167 //===----------------------------------------------------------------------===//
168 
169 OperationState::OperationState(Location location, StringRef name)
170  : location(location), name(name, location->getContext()) {}
171 
173  : location(location), name(name) {}
174 
176  ValueRange operands, TypeRange types,
177  ArrayRef<NamedAttribute> attributes,
178  BlockRange successors,
179  MutableArrayRef<std::unique_ptr<Region>> regions)
180  : location(location), name(name),
181  operands(operands.begin(), operands.end()),
182  types(types.begin(), types.end()),
183  attributes(attributes.begin(), attributes.end()),
184  successors(successors.begin(), successors.end()) {
185  for (std::unique_ptr<Region> &r : regions)
186  this->regions.push_back(std::move(r));
187 }
188 OperationState::OperationState(Location location, StringRef name,
189  ValueRange operands, TypeRange types,
190  ArrayRef<NamedAttribute> attributes,
191  BlockRange successors,
192  MutableArrayRef<std::unique_ptr<Region>> regions)
193  : OperationState(location, OperationName(name, location.getContext()),
194  operands, types, attributes, successors, regions) {}
195 
197  if (properties)
198  propertiesDeleter(properties);
199 }
200 
203  if (LLVM_UNLIKELY(propertiesAttr)) {
204  assert(!properties);
206  }
207  if (properties)
208  propertiesSetter(op->getPropertiesStorage(), properties);
209  return success();
210 }
211 
213  operands.append(newOperands.begin(), newOperands.end());
214 }
215 
217  successors.append(newSuccessors.begin(), newSuccessors.end());
218 }
219 
221  regions.emplace_back(new Region);
222  return regions.back().get();
223 }
224 
225 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
226  regions.push_back(std::move(region));
227 }
228 
230  MutableArrayRef<std::unique_ptr<Region>> regions) {
231  for (std::unique_ptr<Region> &region : regions)
232  addRegion(std::move(region));
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // OperandStorage
237 //===----------------------------------------------------------------------===//
238 
240  OpOperand *trailingOperands,
241  ValueRange values)
242  : isStorageDynamic(false), operandStorage(trailingOperands) {
243  numOperands = capacity = values.size();
244  for (unsigned i = 0; i < numOperands; ++i)
245  new (&operandStorage[i]) OpOperand(owner, values[i]);
246 }
247 
249  for (auto &operand : getOperands())
250  operand.~OpOperand();
251 
252  // If the storage is dynamic, deallocate it.
253  if (isStorageDynamic)
254  free(operandStorage);
255 }
256 
257 /// Replace the operands contained in the storage with the ones provided in
258 /// 'values'.
260  MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
261  for (unsigned i = 0, e = values.size(); i != e; ++i)
262  storageOperands[i].set(values[i]);
263 }
264 
265 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
266 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
267 /// than the range pointed to by 'start'+'length'.
268 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
269  unsigned length, ValueRange operands) {
270  // If the new size is the same, we can update inplace.
271  unsigned newSize = operands.size();
272  if (newSize == length) {
273  MutableArrayRef<OpOperand> storageOperands = getOperands();
274  for (unsigned i = 0, e = length; i != e; ++i)
275  storageOperands[start + i].set(operands[i]);
276  return;
277  }
278  // If the new size is greater, remove the extra operands and set the rest
279  // inplace.
280  if (newSize < length) {
281  eraseOperands(start + operands.size(), length - newSize);
282  setOperands(owner, start, newSize, operands);
283  return;
284  }
285  // Otherwise, the new size is greater so we need to grow the storage.
286  auto storageOperands = resize(owner, size() + (newSize - length));
287 
288  // Shift operands to the right to make space for the new operands.
289  unsigned rotateSize = storageOperands.size() - (start + length);
290  auto rbegin = storageOperands.rbegin();
291  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
292 
293  // Update the operands inplace.
294  for (unsigned i = 0, e = operands.size(); i != e; ++i)
295  storageOperands[start + i].set(operands[i]);
296 }
297 
298 /// Erase an operand held by the storage.
299 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
300  MutableArrayRef<OpOperand> operands = getOperands();
301  assert((start + length) <= operands.size());
302  numOperands -= length;
303 
304  // Shift all operands down if the operand to remove is not at the end.
305  if (start != numOperands) {
306  auto *indexIt = std::next(operands.begin(), start);
307  std::rotate(indexIt, std::next(indexIt, length), operands.end());
308  }
309  for (unsigned i = 0; i != length; ++i)
310  operands[numOperands + i].~OpOperand();
311 }
312 
313 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
314  MutableArrayRef<OpOperand> operands = getOperands();
315  assert(eraseIndices.size() == operands.size());
316 
317  // Check that at least one operand is erased.
318  int firstErasedIndice = eraseIndices.find_first();
319  if (firstErasedIndice == -1)
320  return;
321 
322  // Shift all of the removed operands to the end, and destroy them.
323  numOperands = firstErasedIndice;
324  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
325  if (!eraseIndices.test(i))
326  operands[numOperands++] = std::move(operands[i]);
327  for (OpOperand &operand : operands.drop_front(numOperands))
328  operand.~OpOperand();
329 }
330 
331 /// Resize the storage to the given size. Returns the array containing the new
332 /// operands.
333 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
334  unsigned newSize) {
335  // If the number of operands is less than or equal to the current amount, we
336  // can just update in place.
337  MutableArrayRef<OpOperand> origOperands = getOperands();
338  if (newSize <= numOperands) {
339  // If the number of new size is less than the current, remove any extra
340  // operands.
341  for (unsigned i = newSize; i != numOperands; ++i)
342  origOperands[i].~OpOperand();
343  numOperands = newSize;
344  return origOperands.take_front(newSize);
345  }
346 
347  // If the new size is within the original inline capacity, grow inplace.
348  if (newSize <= capacity) {
349  OpOperand *opBegin = origOperands.data();
350  for (unsigned e = newSize; numOperands != e; ++numOperands)
351  new (&opBegin[numOperands]) OpOperand(owner);
352  return MutableArrayRef<OpOperand>(opBegin, newSize);
353  }
354 
355  // Otherwise, we need to allocate a new storage.
356  unsigned newCapacity =
357  std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
358  OpOperand *newOperandStorage =
359  reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
360 
361  // Move the current operands to the new storage.
362  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
363  std::uninitialized_move(origOperands.begin(), origOperands.end(),
364  newOperands.begin());
365 
366  // Destroy the original operands.
367  for (auto &operand : origOperands)
368  operand.~OpOperand();
369 
370  // Initialize any new operands.
371  for (unsigned e = newSize; numOperands != e; ++numOperands)
372  new (&newOperands[numOperands]) OpOperand(owner);
373 
374  // If the current storage is dynamic, free it.
375  if (isStorageDynamic)
376  free(operandStorage);
377 
378  // Update the storage representation to use the new dynamic storage.
379  operandStorage = newOperandStorage;
380  capacity = newCapacity;
381  isStorageDynamic = true;
382  return newOperands;
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // Operation Value-Iterators
387 //===----------------------------------------------------------------------===//
388 
389 //===----------------------------------------------------------------------===//
390 // OperandRange
391 
393  assert(!empty() && "range must not be empty");
394  return base->getOperandNumber();
395 }
396 
398  return OperandRangeRange(*this, segmentSizes);
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // OperandRangeRange
403 
405  Attribute operandSegments)
406  : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
407  llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
408 }
409 
411  const OwnerT &owner = getBase();
412  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
413  return OperandRange(owner.first,
414  std::accumulate(sizeData.begin(), sizeData.end(), 0));
415 }
416 
417 OperandRange OperandRangeRange::dereference(const OwnerT &object,
418  ptrdiff_t index) {
419  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
420  uint32_t startIndex =
421  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
422  return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // MutableOperandRange
427 
428 /// Construct a new mutable range from the given operand, operand start index,
429 /// and range length.
431  Operation *owner, unsigned start, unsigned length,
432  ArrayRef<OperandSegment> operandSegments)
433  : owner(owner), start(start), length(length),
434  operandSegments(operandSegments.begin(), operandSegments.end()) {
435  assert((start + length) <= owner->getNumOperands() && "invalid range");
436 }
438  : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
439 
440 /// Construct a new mutable range for the given OpOperand.
442  : MutableOperandRange(opOperand.getOwner(),
443  /*start=*/opOperand.getOperandNumber(),
444  /*length=*/1) {}
445 
446 /// Slice this range into a sub range, with the additional operand segment.
448 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
449  std::optional<OperandSegment> segment) const {
450  assert((subStart + subLen) <= length && "invalid sub-range");
451  MutableOperandRange subSlice(owner, start + subStart, subLen,
452  operandSegments);
453  if (segment)
454  subSlice.operandSegments.push_back(*segment);
455  return subSlice;
456 }
457 
458 /// Append the given values to the range.
460  if (values.empty())
461  return;
462  owner->insertOperands(start + length, values);
463  updateLength(length + values.size());
464 }
465 
466 /// Assign this range to the given values.
468  owner->setOperands(start, length, values);
469  if (length != values.size())
470  updateLength(/*newLength=*/values.size());
471 }
472 
473 /// Assign the range to the given value.
475  if (length == 1) {
476  owner->setOperand(start, value);
477  } else {
478  owner->setOperands(start, length, value);
479  updateLength(/*newLength=*/1);
480  }
481 }
482 
483 /// Erase the operands within the given sub-range.
484 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
485  assert((subStart + subLen) <= length && "invalid sub-range");
486  if (length == 0)
487  return;
488  owner->eraseOperands(start + subStart, subLen);
489  updateLength(length - subLen);
490 }
491 
492 /// Clear this range and erase all of the operands.
494  if (length != 0) {
495  owner->eraseOperands(start, length);
496  updateLength(/*newLength=*/0);
497  }
498 }
499 
500 /// Explicit conversion to an OperandRange.
502  return owner->getOperands().slice(start, length);
503 }
504 
505 /// Allow implicit conversion to an OperandRange.
506 MutableOperandRange::operator OperandRange() const {
507  return getAsOperandRange();
508 }
509 
510 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
511  return owner->getOpOperands().slice(start, length);
512 }
513 
516  return MutableOperandRangeRange(*this, segmentSizes);
517 }
518 
519 /// Update the length of this range to the one provided.
520 void MutableOperandRange::updateLength(unsigned newLength) {
521  int32_t diff = int32_t(newLength) - int32_t(length);
522  length = newLength;
523 
524  // Update any of the provided segment attributes.
525  for (OperandSegment &segment : operandSegments) {
526  auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
527  SmallVector<int32_t, 8> segments(attr.asArrayRef());
528  segments[segment.first] += diff;
529  segment.second.setValue(
530  DenseI32ArrayAttr::get(attr.getContext(), segments));
531  owner->setAttr(segment.second.getName(), segment.second.getValue());
532  }
533 }
534 
536  assert(index < length && "index is out of bounds");
537  return owner->getOpOperand(start + index);
538 }
539 
541  return owner->getOpOperands().slice(start, length).begin();
542 }
543 
545  return owner->getOpOperands().slice(start, length).end();
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // MutableOperandRangeRange
550 
552  const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
554  OwnerT(operands, operandSegmentAttr), 0,
555  llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
556 }
557 
559  return getBase().first;
560 }
561 
562 MutableOperandRangeRange::operator OperandRangeRange() const {
563  return OperandRangeRange(getBase().first, getBase().second.getValue());
564 }
565 
566 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
567  ptrdiff_t index) {
568  ArrayRef<int32_t> sizeData =
569  llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
570  uint32_t startIndex =
571  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
572  return object.first.slice(
573  startIndex, *(sizeData.begin() + index),
574  MutableOperandRange::OperandSegment(index, object.second));
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // ResultRange
579 
581  : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
582  1) {}
583 
585  return {use_begin(), use_end()};
586 }
588  return use_iterator(*this);
589 }
591  return use_iterator(*this, /*end=*/true);
592 }
594  return {user_begin(), user_end()};
595 }
597  return user_iterator(use_begin());
598 }
600  return user_iterator(use_end());
601 }
602 
604  : it(end ? results.end() : results.begin()), endIt(results.end()) {
605  // Only initialize current use if there are results/can be uses.
606  if (it != endIt)
607  skipOverResultsWithNoUsers();
608 }
609 
611  // We increment over uses, if we reach the last use then move to next
612  // result.
613  if (use != (*it).use_end())
614  ++use;
615  if (use == (*it).use_end()) {
616  ++it;
617  skipOverResultsWithNoUsers();
618  }
619  return *this;
620 }
621 
622 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
623  while (it != endIt && (*it).use_empty())
624  ++it;
625 
626  // If we are at the last result, then set use to first use of
627  // first result (sentinel value used for end).
628  if (it == endIt)
629  use = {};
630  else
631  use = (*it).use_begin();
632 }
633 
636 }
637 
639  Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
640  replaceUsesWithIf(op->getResults(), shouldReplace);
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // ValueRange
645 
647  : ValueRange(values.data(), values.size()) {}
649  : ValueRange(values.begin().getBase(), values.size()) {}
651  : ValueRange(values.getBase(), values.size()) {}
652 
653 /// See `llvm::detail::indexed_accessor_range_base` for details.
654 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
655  ptrdiff_t index) {
656  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
657  return {value + index};
658  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
659  return {operand + index};
660  return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
661 }
662 /// See `llvm::detail::indexed_accessor_range_base` for details.
663 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
664  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
665  return value[index];
666  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
667  return operand[index].get();
668  return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // Operation Equivalency
673 //===----------------------------------------------------------------------===//
674 
676  Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
677  function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
678  // Hash operations based upon their:
679  // - Operation Name
680  // - Attributes
681  // - Result Types
682  llvm::hash_code hash =
683  llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(),
684  op->getResultTypes(), op->hashProperties());
685 
686  // - Location if required
687  if (!(flags & Flags::IgnoreLocations))
688  hash = llvm::hash_combine(hash, op->getLoc());
689 
690  // - Operands
692  op->getNumOperands() > 0) {
693  size_t operandHash = hashOperands(op->getOperand(0));
694  for (auto operand : op->getOperands().drop_front())
695  operandHash += hashOperands(operand);
696  hash = llvm::hash_combine(hash, operandHash);
697  } else {
698  for (Value operand : op->getOperands())
699  hash = llvm::hash_combine(hash, hashOperands(operand));
700  }
701 
702  // - Results
703  for (Value result : op->getResults())
704  hash = llvm::hash_combine(hash, hashResults(result));
705  return hash;
706 }
707 
709  Region *lhs, Region *rhs,
710  function_ref<LogicalResult(Value, Value)> checkEquivalent,
711  function_ref<void(Value, Value)> markEquivalent,
714  checkCommutativeEquivalent) {
715  DenseMap<Block *, Block *> blocksMap;
716  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
717  // Check block arguments.
718  if (lBlock.getNumArguments() != rBlock.getNumArguments())
719  return false;
720 
721  // Map the two blocks.
722  auto insertion = blocksMap.insert({&lBlock, &rBlock});
723  if (insertion.first->getSecond() != &rBlock)
724  return false;
725 
726  for (auto argPair :
727  llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
728  Value curArg = std::get<0>(argPair);
729  Value otherArg = std::get<1>(argPair);
730  if (curArg.getType() != otherArg.getType())
731  return false;
732  if (!(flags & OperationEquivalence::IgnoreLocations) &&
733  curArg.getLoc() != otherArg.getLoc())
734  return false;
735  // Corresponding bbArgs are equivalent.
736  if (markEquivalent)
737  markEquivalent(curArg, otherArg);
738  }
739 
740  auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
741  // Check for op equality (recursively).
742  if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
743  markEquivalent, flags,
744  checkCommutativeEquivalent))
745  return false;
746  // Check successor mapping.
747  for (auto successorsPair :
748  llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
749  Block *curSuccessor = std::get<0>(successorsPair);
750  Block *otherSuccessor = std::get<1>(successorsPair);
751  auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
752  if (insertion.first->getSecond() != otherSuccessor)
753  return false;
754  }
755  return true;
756  };
757  return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
758  };
759  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
760 }
761 
762 // Value equivalence cache to be used with `isRegionEquivalentTo` and
763 // `isEquivalentTo`.
766  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
767  return success(lhsValue == rhsValue ||
768  equivalentValues.lookup(lhsValue) == rhsValue);
769  }
771  ValueRange rhsRange) {
772  // Handle simple case where sizes mismatch.
773  if (lhsRange.size() != rhsRange.size())
774  return failure();
775 
776  // Handle where operands in order are equivalent.
777  auto lhsIt = lhsRange.begin();
778  auto rhsIt = rhsRange.begin();
779  for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
780  if (failed(checkEquivalent(*lhsIt, *rhsIt)))
781  break;
782  }
783  if (lhsIt == lhsRange.end())
784  return success();
785 
786  // Handle another simple case where operands are just a permutation.
787  // Note: This is not sufficient, this handles simple cases relatively
788  // cheaply.
789  auto sortValues = [](ValueRange values) {
790  SmallVector<Value> sortedValues = llvm::to_vector(values);
791  llvm::sort(sortedValues, [](Value a, Value b) {
792  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
793  });
794  return sortedValues;
795  };
796  auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
797  auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
798  return success(lhsSorted == rhsSorted);
799  }
800  void markEquivalent(Value lhsResult, Value rhsResult) {
801  auto insertion = equivalentValues.insert({lhsResult, rhsResult});
802  // Make sure that the value was not already marked equivalent to some other
803  // value.
804  (void)insertion;
805  assert(insertion.first->second == rhsResult &&
806  "inconsistent OperationEquivalence state");
807  }
808 };
809 
810 /*static*/ bool
813  ValueEquivalenceCache cache;
814  return isRegionEquivalentTo(
815  lhs, rhs,
816  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
817  return cache.checkEquivalent(lhsValue, rhsValue);
818  },
819  [&](Value lhsResult, Value rhsResult) {
820  cache.markEquivalent(lhsResult, rhsResult);
821  },
822  flags,
823  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
824  return cache.checkCommutativeEquivalent(lhs, rhs);
825  });
826 }
827 
829  Operation *lhs, Operation *rhs,
830  function_ref<LogicalResult(Value, Value)> checkEquivalent,
831  function_ref<void(Value, Value)> markEquivalent, Flags flags,
833  checkCommutativeEquivalent) {
834  if (lhs == rhs)
835  return true;
836 
837  // 1. Compare the operation properties.
838  if (lhs->getName() != rhs->getName() ||
839  lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() ||
840  lhs->getNumRegions() != rhs->getNumRegions() ||
841  lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
842  lhs->getNumOperands() != rhs->getNumOperands() ||
843  lhs->getNumResults() != rhs->getNumResults() ||
845  rhs->getPropertiesStorage()))
846  return false;
847  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
848  return false;
849 
850  // 2. Compare operands.
851  if (checkCommutativeEquivalent &&
853  auto lhsRange = lhs->getOperands();
854  auto rhsRange = rhs->getOperands();
855  if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
856  return false;
857  } else {
858  // Check pair wise for equivalence.
859  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
860  Value curArg = std::get<0>(operandPair);
861  Value otherArg = std::get<1>(operandPair);
862  if (curArg == otherArg)
863  continue;
864  if (curArg.getType() != otherArg.getType())
865  return false;
866  if (failed(checkEquivalent(curArg, otherArg)))
867  return false;
868  }
869  }
870 
871  // 3. Compare result types and mark results as equivalent.
872  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
873  Value curArg = std::get<0>(resultPair);
874  Value otherArg = std::get<1>(resultPair);
875  if (curArg.getType() != otherArg.getType())
876  return false;
877  if (markEquivalent)
878  markEquivalent(curArg, otherArg);
879  }
880 
881  // 4. Compare regions.
882  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
883  if (!isRegionEquivalentTo(&std::get<0>(regionPair),
884  &std::get<1>(regionPair), checkEquivalent,
885  markEquivalent, flags))
886  return false;
887 
888  return true;
889 }
890 
892  Operation *rhs,
893  Flags flags) {
894  ValueEquivalenceCache cache;
896  lhs, rhs,
897  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
898  return cache.checkEquivalent(lhsValue, rhsValue);
899  },
900  [&](Value lhsResult, Value rhsResult) {
901  cache.markEquivalent(lhsResult, rhsResult);
902  },
903  flags,
904  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
905  return cache.checkCommutativeEquivalent(lhs, rhs);
906  });
907 }
908 
909 //===----------------------------------------------------------------------===//
910 // OperationFingerPrint
911 //===----------------------------------------------------------------------===//
912 
913 template <typename T>
914 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
915  hasher.update(
916  ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
917 }
918 
920  bool includeNested) {
921  llvm::SHA1 hasher;
922 
923  // Helper function that hashes an operation based on its mutable bits:
924  auto addOperationToHash = [&](Operation *op) {
925  // - Operation pointer
926  addDataToHash(hasher, op);
927  // - Parent operation pointer (to take into account the nesting structure)
928  if (op != topOp)
929  addDataToHash(hasher, op->getParentOp());
930  // - Attributes
931  addDataToHash(hasher, op->getRawDictionaryAttrs());
932  // - Properties
933  addDataToHash(hasher, op->hashProperties());
934  // - Blocks in Regions
935  for (Region &region : op->getRegions()) {
936  for (Block &block : region) {
937  addDataToHash(hasher, &block);
938  for (BlockArgument arg : block.getArguments())
939  addDataToHash(hasher, arg);
940  }
941  }
942  // - Location
943  addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
944  // - Operands
945  for (Value operand : op->getOperands())
946  addDataToHash(hasher, operand);
947  // - Successors
948  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
949  addDataToHash(hasher, op->getSuccessor(i));
950  // - Result types
951  for (Type t : op->getResultTypes())
952  addDataToHash(hasher, t);
953  };
954 
955  if (includeNested)
956  topOp->walk(addOperationToHash);
957  else
958  addOperationToHash(topOp);
959 
960  hash = hasher.result();
961 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:319
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumArguments()
Definition: Block.h:125
BlockArgListType getArguments()
Definition: Block.h:84
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Location.h:104
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a contiguous range of mutable operand ranges, e.g.
Definition: ValueRange.h:206
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments=std::nullopt)
Construct a new mutable range from the given operand, operand start index, and range length.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition: ValueRange.h:120
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists, else returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class adds property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:82
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp, bool includeNested=true)
bool compareOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:383
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
Block * getSuccessor(unsigned index)
Definition: Operation.h:704
unsigned getNumSuccessors()
Definition: Operation.h:702
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition: Operation.h:355
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:504
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
Definition: Operation.cpp:355
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
SuccessorRange getSuccessors()
Definition: Operation.h:699
result_range getResults()
Definition: Operation.h:410
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
Definition: Operation.cpp:370
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements a use iterator for a range of operation results.
Definition: ValueRange.h:345
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition: ValueRange.h:317
ResultRange(OpResult result)
use_iterator use_end() const
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition: ValueRange.h:298
user_iterator user_end()
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:281
user_iterator user_begin()
UseIterator use_iterator
Definition: ValueRange.h:262
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition: ValueRange.h:386
ValueRange(Arg &&arg)
Definition: ValueRange.h:394
An iterator over the users of an IRObject.
Definition: UseDefLists.h:344
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:243
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
This class provides the implementation for an operation result.
Definition: Value.h:368
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
Include the generated interface declarations.
Definition: CallGraph.h:229
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, ValueRange rhsRange)
DenseMap< Value, Value > equivalentValues
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addRegions(MutableArrayRef< std::unique_ptr< Region >> regions)
Take ownership of a set of regions that should be attached to the Operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const