MLIR  17.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include "llvm/Support/SHA1.h"
20 #include <numeric>
21 #include <optional>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // NamedAttrList
27 //===----------------------------------------------------------------------===//
28 
30  assign(attributes.begin(), attributes.end());
31 }
32 
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34  : NamedAttrList(attributes ? attributes.getValue()
35  : ArrayRef<NamedAttribute>()) {
36  dictionarySorted.setPointerAndInt(attributes, true);
37 }
38 
40  assign(inStart, inEnd);
41 }
42 
44 
45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46  std::optional<NamedAttribute> duplicate =
47  DictionaryAttr::findDuplicate(attrs, isSorted());
48  // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49  // state.
50  if (!isSorted())
51  dictionarySorted.setPointerAndInt(nullptr, true);
52  return duplicate;
53 }
54 
55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56  if (!isSorted()) {
57  DictionaryAttr::sortInPlace(attrs);
58  dictionarySorted.setPointerAndInt(nullptr, true);
59  }
60  if (!dictionarySorted.getPointer())
61  dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62  return dictionarySorted.getPointer().cast<DictionaryAttr>();
63 }
64 
65 /// Add an attribute with the specified name.
66 void NamedAttrList::append(StringRef name, Attribute attr) {
67  append(StringAttr::get(attr.getContext(), name), attr);
68 }
69 
70 /// Replaces the attributes with new list of attributes.
72  DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
73  dictionarySorted.setPointerAndInt(nullptr, true);
74 }
75 
77  if (isSorted())
78  dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
79  dictionarySorted.setPointer(nullptr);
80  attrs.push_back(newAttribute);
81 }
82 
83 /// Return the specified attribute if present, null otherwise.
84 Attribute NamedAttrList::get(StringRef name) const {
85  auto it = findAttr(*this, name);
86  return it.second ? it.first->getValue() : Attribute();
87 }
88 Attribute NamedAttrList::get(StringAttr name) const {
89  auto it = findAttr(*this, name);
90  return it.second ? it.first->getValue() : Attribute();
91 }
92 
93 /// Return the specified named attribute if present, std::nullopt otherwise.
94 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
95  auto it = findAttr(*this, name);
96  return it.second ? *it.first : std::optional<NamedAttribute>();
97 }
98 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
99  auto it = findAttr(*this, name);
100  return it.second ? *it.first : std::optional<NamedAttribute>();
101 }
102 
103 /// If the an attribute exists with the specified name, change it to the new
104 /// value. Otherwise, add a new attribute with the specified name/value.
105 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
106  assert(value && "attributes may never be null");
107 
108  // Look for an existing attribute with the given name, and set its value
109  // in-place. Return the previous value of the attribute, if there was one.
110  auto it = findAttr(*this, name);
111  if (it.second) {
112  // Update the existing attribute by swapping out the old value for the new
113  // value. Return the old value.
114  Attribute oldValue = it.first->getValue();
115  if (it.first->getValue() != value) {
116  it.first->setValue(value);
117 
118  // If the attributes have changed, the dictionary is invalidated.
119  dictionarySorted.setPointer(nullptr);
120  }
121  return oldValue;
122  }
123  // Perform a string lookup to insert the new attribute into its sorted
124  // position.
125  if (isSorted())
126  it = findAttr(*this, name.strref());
127  attrs.insert(it.first, {name, value});
128  // Invalidate the dictionary. Return null as there was no previous value.
129  dictionarySorted.setPointer(nullptr);
130  return Attribute();
131 }
132 
133 Attribute NamedAttrList::set(StringRef name, Attribute value) {
134  assert(value && "attributes may never be null");
135  return set(mlir::StringAttr::get(value.getContext(), name), value);
136 }
137 
138 Attribute
139 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
140  // Erasing does not affect the sorted property.
141  Attribute attr = it->getValue();
142  attrs.erase(it);
143  dictionarySorted.setPointer(nullptr);
144  return attr;
145 }
146 
147 Attribute NamedAttrList::erase(StringAttr name) {
148  auto it = findAttr(*this, name);
149  return it.second ? eraseImpl(it.first) : Attribute();
150 }
151 
153  auto it = findAttr(*this, name);
154  return it.second ? eraseImpl(it.first) : Attribute();
155 }
156 
159  assign(rhs.begin(), rhs.end());
160  return *this;
161 }
162 
163 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
164 
165 //===----------------------------------------------------------------------===//
166 // OperationState
167 //===----------------------------------------------------------------------===//
168 
169 OperationState::OperationState(Location location, StringRef name)
170  : location(location), name(name, location->getContext()) {}
171 
173  : location(location), name(name) {}
174 
176  ValueRange operands, TypeRange types,
177  ArrayRef<NamedAttribute> attributes,
178  BlockRange successors,
179  MutableArrayRef<std::unique_ptr<Region>> regions)
180  : location(location), name(name),
181  operands(operands.begin(), operands.end()),
182  types(types.begin(), types.end()),
183  attributes(attributes.begin(), attributes.end()),
184  successors(successors.begin(), successors.end()) {
185  for (std::unique_ptr<Region> &r : regions)
186  this->regions.push_back(std::move(r));
187 }
188 OperationState::OperationState(Location location, StringRef name,
189  ValueRange operands, TypeRange types,
190  ArrayRef<NamedAttribute> attributes,
191  BlockRange successors,
192  MutableArrayRef<std::unique_ptr<Region>> regions)
193  : OperationState(location, OperationName(name, location.getContext()),
194  operands, types, attributes, successors, regions) {}
195 
197  operands.append(newOperands.begin(), newOperands.end());
198 }
199 
201  successors.append(newSuccessors.begin(), newSuccessors.end());
202 }
203 
205  regions.emplace_back(new Region);
206  return regions.back().get();
207 }
208 
209 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
210  regions.push_back(std::move(region));
211 }
212 
214  MutableArrayRef<std::unique_ptr<Region>> regions) {
215  for (std::unique_ptr<Region> &region : regions)
216  addRegion(std::move(region));
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // OperandStorage
221 //===----------------------------------------------------------------------===//
222 
224  OpOperand *trailingOperands,
225  ValueRange values)
226  : isStorageDynamic(false), operandStorage(trailingOperands) {
227  numOperands = capacity = values.size();
228  for (unsigned i = 0; i < numOperands; ++i)
229  new (&operandStorage[i]) OpOperand(owner, values[i]);
230 }
231 
233  for (auto &operand : getOperands())
234  operand.~OpOperand();
235 
236  // If the storage is dynamic, deallocate it.
237  if (isStorageDynamic)
238  free(operandStorage);
239 }
240 
241 /// Replace the operands contained in the storage with the ones provided in
242 /// 'values'.
244  MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
245  for (unsigned i = 0, e = values.size(); i != e; ++i)
246  storageOperands[i].set(values[i]);
247 }
248 
249 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
250 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
251 /// than the range pointed to by 'start'+'length'.
252 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
253  unsigned length, ValueRange operands) {
254  // If the new size is the same, we can update inplace.
255  unsigned newSize = operands.size();
256  if (newSize == length) {
257  MutableArrayRef<OpOperand> storageOperands = getOperands();
258  for (unsigned i = 0, e = length; i != e; ++i)
259  storageOperands[start + i].set(operands[i]);
260  return;
261  }
262  // If the new size is greater, remove the extra operands and set the rest
263  // inplace.
264  if (newSize < length) {
265  eraseOperands(start + operands.size(), length - newSize);
266  setOperands(owner, start, newSize, operands);
267  return;
268  }
269  // Otherwise, the new size is greater so we need to grow the storage.
270  auto storageOperands = resize(owner, size() + (newSize - length));
271 
272  // Shift operands to the right to make space for the new operands.
273  unsigned rotateSize = storageOperands.size() - (start + length);
274  auto rbegin = storageOperands.rbegin();
275  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
276 
277  // Update the operands inplace.
278  for (unsigned i = 0, e = operands.size(); i != e; ++i)
279  storageOperands[start + i].set(operands[i]);
280 }
281 
282 /// Erase an operand held by the storage.
283 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
284  MutableArrayRef<OpOperand> operands = getOperands();
285  assert((start + length) <= operands.size());
286  numOperands -= length;
287 
288  // Shift all operands down if the operand to remove is not at the end.
289  if (start != numOperands) {
290  auto *indexIt = std::next(operands.begin(), start);
291  std::rotate(indexIt, std::next(indexIt, length), operands.end());
292  }
293  for (unsigned i = 0; i != length; ++i)
294  operands[numOperands + i].~OpOperand();
295 }
296 
297 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
298  MutableArrayRef<OpOperand> operands = getOperands();
299  assert(eraseIndices.size() == operands.size());
300 
301  // Check that at least one operand is erased.
302  int firstErasedIndice = eraseIndices.find_first();
303  if (firstErasedIndice == -1)
304  return;
305 
306  // Shift all of the removed operands to the end, and destroy them.
307  numOperands = firstErasedIndice;
308  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
309  if (!eraseIndices.test(i))
310  operands[numOperands++] = std::move(operands[i]);
311  for (OpOperand &operand : operands.drop_front(numOperands))
312  operand.~OpOperand();
313 }
314 
315 /// Resize the storage to the given size. Returns the array containing the new
316 /// operands.
317 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
318  unsigned newSize) {
319  // If the number of operands is less than or equal to the current amount, we
320  // can just update in place.
321  MutableArrayRef<OpOperand> origOperands = getOperands();
322  if (newSize <= numOperands) {
323  // If the number of new size is less than the current, remove any extra
324  // operands.
325  for (unsigned i = newSize; i != numOperands; ++i)
326  origOperands[i].~OpOperand();
327  numOperands = newSize;
328  return origOperands.take_front(newSize);
329  }
330 
331  // If the new size is within the original inline capacity, grow inplace.
332  if (newSize <= capacity) {
333  OpOperand *opBegin = origOperands.data();
334  for (unsigned e = newSize; numOperands != e; ++numOperands)
335  new (&opBegin[numOperands]) OpOperand(owner);
336  return MutableArrayRef<OpOperand>(opBegin, newSize);
337  }
338 
339  // Otherwise, we need to allocate a new storage.
340  unsigned newCapacity =
341  std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
342  OpOperand *newOperandStorage =
343  reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
344 
345  // Move the current operands to the new storage.
346  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
347  std::uninitialized_move(origOperands.begin(), origOperands.end(),
348  newOperands.begin());
349 
350  // Destroy the original operands.
351  for (auto &operand : origOperands)
352  operand.~OpOperand();
353 
354  // Initialize any new operands.
355  for (unsigned e = newSize; numOperands != e; ++numOperands)
356  new (&newOperands[numOperands]) OpOperand(owner);
357 
358  // If the current storage is dynamic, free it.
359  if (isStorageDynamic)
360  free(operandStorage);
361 
362  // Update the storage representation to use the new dynamic storage.
363  operandStorage = newOperandStorage;
364  capacity = newCapacity;
365  isStorageDynamic = true;
366  return newOperands;
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // Operation Value-Iterators
371 //===----------------------------------------------------------------------===//
372 
373 //===----------------------------------------------------------------------===//
374 // OperandRange
375 
377  assert(!empty() && "range must not be empty");
378  return base->getOperandNumber();
379 }
380 
382  return OperandRangeRange(*this, segmentSizes);
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // OperandRangeRange
387 
389  Attribute operandSegments)
390  : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
391  operandSegments.cast<DenseI32ArrayAttr>().size()) {}
392 
394  const OwnerT &owner = getBase();
395  ArrayRef<int32_t> sizeData = owner.second.cast<DenseI32ArrayAttr>();
396  return OperandRange(owner.first,
397  std::accumulate(sizeData.begin(), sizeData.end(), 0));
398 }
399 
400 OperandRange OperandRangeRange::dereference(const OwnerT &object,
401  ptrdiff_t index) {
402  ArrayRef<int32_t> sizeData = object.second.cast<DenseI32ArrayAttr>();
403  uint32_t startIndex =
404  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
405  return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // MutableOperandRange
410 
411 /// Construct a new mutable range from the given operand, operand start index,
412 /// and range length.
414  Operation *owner, unsigned start, unsigned length,
415  ArrayRef<OperandSegment> operandSegments)
416  : owner(owner), start(start), length(length),
417  operandSegments(operandSegments.begin(), operandSegments.end()) {
418  assert((start + length) <= owner->getNumOperands() && "invalid range");
419 }
421  : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
422 
423 /// Slice this range into a sub range, with the additional operand segment.
425 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
426  std::optional<OperandSegment> segment) const {
427  assert((subStart + subLen) <= length && "invalid sub-range");
428  MutableOperandRange subSlice(owner, start + subStart, subLen,
429  operandSegments);
430  if (segment)
431  subSlice.operandSegments.push_back(*segment);
432  return subSlice;
433 }
434 
435 /// Append the given values to the range.
437  if (values.empty())
438  return;
439  owner->insertOperands(start + length, values);
440  updateLength(length + values.size());
441 }
442 
443 /// Assign this range to the given values.
445  owner->setOperands(start, length, values);
446  if (length != values.size())
447  updateLength(/*newLength=*/values.size());
448 }
449 
450 /// Assign the range to the given value.
452  if (length == 1) {
453  owner->setOperand(start, value);
454  } else {
455  owner->setOperands(start, length, value);
456  updateLength(/*newLength=*/1);
457  }
458 }
459 
460 /// Erase the operands within the given sub-range.
461 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
462  assert((subStart + subLen) <= length && "invalid sub-range");
463  if (length == 0)
464  return;
465  owner->eraseOperands(start + subStart, subLen);
466  updateLength(length - subLen);
467 }
468 
469 /// Clear this range and erase all of the operands.
471  if (length != 0) {
472  owner->eraseOperands(start, length);
473  updateLength(/*newLength=*/0);
474  }
475 }
476 
477 /// Allow implicit conversion to an OperandRange.
478 MutableOperandRange::operator OperandRange() const {
479  return owner->getOperands().slice(start, length);
480 }
481 
484  return MutableOperandRangeRange(*this, segmentSizes);
485 }
486 
487 /// Update the length of this range to the one provided.
488 void MutableOperandRange::updateLength(unsigned newLength) {
489  int32_t diff = int32_t(newLength) - int32_t(length);
490  length = newLength;
491 
492  // Update any of the provided segment attributes.
493  for (OperandSegment &segment : operandSegments) {
494  auto attr = segment.second.getValue().cast<DenseI32ArrayAttr>();
495  SmallVector<int32_t, 8> segments(attr.asArrayRef());
496  segments[segment.first] += diff;
497  segment.second.setValue(
498  DenseI32ArrayAttr::get(attr.getContext(), segments));
499  owner->setAttr(segment.second.getName(), segment.second.getValue());
500  }
501 }
502 
503 //===----------------------------------------------------------------------===//
504 // MutableOperandRangeRange
505 
507  const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
509  OwnerT(operands, operandSegmentAttr), 0,
510  operandSegmentAttr.getValue().cast<DenseI32ArrayAttr>().size()) {}
511 
513  return getBase().first;
514 }
515 
516 MutableOperandRangeRange::operator OperandRangeRange() const {
517  return OperandRangeRange(getBase().first, getBase().second.getValue());
518 }
519 
520 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
521  ptrdiff_t index) {
522  ArrayRef<int32_t> sizeData =
523  object.second.getValue().cast<DenseI32ArrayAttr>();
524  uint32_t startIndex =
525  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
526  return object.first.slice(
527  startIndex, *(sizeData.begin() + index),
528  MutableOperandRange::OperandSegment(index, object.second));
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // ResultRange
533 
535  : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
536  1) {}
537 
539  return {use_begin(), use_end()};
540 }
542  return use_iterator(*this);
543 }
545  return use_iterator(*this, /*end=*/true);
546 }
548  return {user_begin(), user_end()};
549 }
551  return user_iterator(use_begin());
552 }
554  return user_iterator(use_end());
555 }
556 
558  : it(end ? results.end() : results.begin()), endIt(results.end()) {
559  // Only initialize current use if there are results/can be uses.
560  if (it != endIt)
561  skipOverResultsWithNoUsers();
562 }
563 
565  // We increment over uses, if we reach the last use then move to next
566  // result.
567  if (use != (*it).use_end())
568  ++use;
569  if (use == (*it).use_end()) {
570  ++it;
571  skipOverResultsWithNoUsers();
572  }
573  return *this;
574 }
575 
576 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
577  while (it != endIt && (*it).use_empty())
578  ++it;
579 
580  // If we are at the last result, then set use to first use of
581  // first result (sentinel value used for end).
582  if (it == endIt)
583  use = {};
584  else
585  use = (*it).use_begin();
586 }
587 
590 }
591 
593  Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
594  replaceUsesWithIf(op->getResults(), shouldReplace);
595 }
596 
597 //===----------------------------------------------------------------------===//
598 // ValueRange
599 
601  : ValueRange(values.data(), values.size()) {}
603  : ValueRange(values.begin().getBase(), values.size()) {}
605  : ValueRange(values.getBase(), values.size()) {}
606 
607 /// See `llvm::detail::indexed_accessor_range_base` for details.
608 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
609  ptrdiff_t index) {
610  if (const auto *value = owner.dyn_cast<const Value *>())
611  return {value + index};
612  if (auto *operand = owner.dyn_cast<OpOperand *>())
613  return {operand + index};
614  return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
615 }
616 /// See `llvm::detail::indexed_accessor_range_base` for details.
617 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
618  if (const auto *value = owner.dyn_cast<const Value *>())
619  return value[index];
620  if (auto *operand = owner.dyn_cast<OpOperand *>())
621  return operand[index].get();
622  return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
623 }
624 
625 //===----------------------------------------------------------------------===//
626 // Operation Equivalency
627 //===----------------------------------------------------------------------===//
628 
630  Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
631  function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
632  // Hash operations based upon their:
633  // - Operation Name
634  // - Attributes
635  // - Result Types
636  llvm::hash_code hash = llvm::hash_combine(
637  op->getName(), op->getAttrDictionary(), op->getResultTypes());
638 
639  // - Operands
640  ValueRange operands = op->getOperands();
641  SmallVector<Value> operandStorage;
643  operandStorage.append(operands.begin(), operands.end());
644  llvm::sort(operandStorage, [](Value a, Value b) -> bool {
645  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
646  });
647  operands = operandStorage;
648  }
649  for (Value operand : operands)
650  hash = llvm::hash_combine(hash, hashOperands(operand));
651 
652  // - Results
653  for (Value result : op->getResults())
654  hash = llvm::hash_combine(hash, hashResults(result));
655  return hash;
656 }
657 
659  Region *lhs, Region *rhs,
660  function_ref<LogicalResult(Value, Value)> checkEquivalent,
661  function_ref<void(Value, Value)> markEquivalent,
663  DenseMap<Block *, Block *> blocksMap;
664  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
665  // Check block arguments.
666  if (lBlock.getNumArguments() != rBlock.getNumArguments())
667  return false;
668 
669  // Map the two blocks.
670  auto insertion = blocksMap.insert({&lBlock, &rBlock});
671  if (insertion.first->getSecond() != &rBlock)
672  return false;
673 
674  for (auto argPair :
675  llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
676  Value curArg = std::get<0>(argPair);
677  Value otherArg = std::get<1>(argPair);
678  if (curArg.getType() != otherArg.getType())
679  return false;
680  if (!(flags & OperationEquivalence::IgnoreLocations) &&
681  curArg.getLoc() != otherArg.getLoc())
682  return false;
683  // Corresponding bbArgs are equivalent.
684  if (markEquivalent)
685  markEquivalent(curArg, otherArg);
686  }
687 
688  auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
689  // Check for op equality (recursively).
690  if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
691  markEquivalent, flags))
692  return false;
693  // Check successor mapping.
694  for (auto successorsPair :
695  llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
696  Block *curSuccessor = std::get<0>(successorsPair);
697  Block *otherSuccessor = std::get<1>(successorsPair);
698  auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
699  if (insertion.first->getSecond() != otherSuccessor)
700  return false;
701  }
702  return true;
703  };
704  return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
705  };
706  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
707 }
708 
709 // Value equivalence cache to be used with `isRegionEquivalentTo` and
710 // `isEquivalentTo`.
713  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
714  return success(lhsValue == rhsValue ||
715  equivalentValues.lookup(lhsValue) == rhsValue);
716  }
717  void markEquivalent(Value lhsResult, Value rhsResult) {
718  auto insertion = equivalentValues.insert({lhsResult, rhsResult});
719  // Make sure that the value was not already marked equivalent to some other
720  // value.
721  (void)insertion;
722  assert(insertion.first->second == rhsResult &&
723  "inconsistent OperationEquivalence state");
724  }
725 };
726 
727 /*static*/ bool
730  ValueEquivalenceCache cache;
731  return isRegionEquivalentTo(
732  lhs, rhs,
733  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
734  return cache.checkEquivalent(lhsValue, rhsValue);
735  },
736  [&](Value lhsResult, Value rhsResult) {
737  cache.markEquivalent(lhsResult, rhsResult);
738  },
739  flags);
740 }
741 
743  Operation *lhs, Operation *rhs,
744  function_ref<LogicalResult(Value, Value)> checkEquivalent,
745  function_ref<void(Value, Value)> markEquivalent, Flags flags) {
746  if (lhs == rhs)
747  return true;
748 
749  // 1. Compare the operation properties.
750  if (lhs->getName() != rhs->getName() ||
751  lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
752  lhs->getNumRegions() != rhs->getNumRegions() ||
753  lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
754  lhs->getNumOperands() != rhs->getNumOperands() ||
755  lhs->getNumResults() != rhs->getNumResults())
756  return false;
757  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
758  return false;
759 
760  // 2. Compare operands.
761  ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
762  SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
764  auto sortValues = [](ValueRange values) {
765  SmallVector<Value> sortedValues = llvm::to_vector(values);
766  llvm::sort(sortedValues, [](Value a, Value b) {
767  auto aArg = a.dyn_cast<BlockArgument>();
768  auto bArg = b.dyn_cast<BlockArgument>();
769 
770  // Case 1. Both `a` and `b` are `BlockArgument`s.
771  if (aArg && bArg) {
772  if (aArg.getParentBlock() == bArg.getParentBlock())
773  return aArg.getArgNumber() < bArg.getArgNumber();
774  return aArg.getParentBlock() < bArg.getParentBlock();
775  }
776 
777  // Case 2. One of them is a `BlockArgument` and the other is not. Treat
778  // `BlockArgument` as lesser.
779  if (aArg && !bArg)
780  return true;
781  if (bArg && !aArg)
782  return false;
783 
784  // Case 3. Both are values.
785  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
786  });
787  return sortedValues;
788  };
789  lhsOperandStorage = sortValues(lhsOperands);
790  lhsOperands = lhsOperandStorage;
791  rhsOperandStorage = sortValues(rhsOperands);
792  rhsOperands = rhsOperandStorage;
793  }
794 
795  for (auto operandPair : llvm::zip(lhsOperands, rhsOperands)) {
796  Value curArg = std::get<0>(operandPair);
797  Value otherArg = std::get<1>(operandPair);
798  if (curArg == otherArg)
799  continue;
800  if (curArg.getType() != otherArg.getType())
801  return false;
802  if (failed(checkEquivalent(curArg, otherArg)))
803  return false;
804  }
805 
806  // 3. Compare result types and mark results as equivalent.
807  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
808  Value curArg = std::get<0>(resultPair);
809  Value otherArg = std::get<1>(resultPair);
810  if (curArg.getType() != otherArg.getType())
811  return false;
812  if (markEquivalent)
813  markEquivalent(curArg, otherArg);
814  }
815 
816  // 4. Compare regions.
817  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
818  if (!isRegionEquivalentTo(&std::get<0>(regionPair),
819  &std::get<1>(regionPair), checkEquivalent,
820  markEquivalent, flags))
821  return false;
822 
823  return true;
824 }
825 
827  Operation *rhs,
828  Flags flags) {
829  ValueEquivalenceCache cache;
831  lhs, rhs,
832  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
833  return cache.checkEquivalent(lhsValue, rhsValue);
834  },
835  [&](Value lhsResult, Value rhsResult) {
836  cache.markEquivalent(lhsResult, rhsResult);
837  },
838  flags);
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // OperationFingerPrint
843 //===----------------------------------------------------------------------===//
844 
845 template <typename T>
846 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
847  hasher.update(
848  ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
849 }
850 
852  llvm::SHA1 hasher;
853 
854  // Hash each of the operations based upon their mutable bits:
855  topOp->walk([&](Operation *op) {
856  // - Operation pointer
857  addDataToHash(hasher, op);
858  // - Attributes
859  addDataToHash(hasher, op->getAttrDictionary());
860  // - Blocks in Regions
861  for (Region &region : op->getRegions()) {
862  for (Block &block : region) {
863  addDataToHash(hasher, &block);
864  for (BlockArgument arg : block.getArguments())
865  addDataToHash(hasher, arg);
866  }
867  }
868  // - Location
869  addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
870  // - Operands
871  for (Value operand : op->getOperands())
872  addDataToHash(hasher, operand);
873  // - Successors
874  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
875  addDataToHash(hasher, op->getSuccessor(i));
876  // - Result types
877  for (Type t : op->getResultTypes())
878  addDataToHash(hasher, t);
879  });
880  hash = hasher.result();
881 }
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U cast() const
Definition: Attributes.h:176
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:304
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:316
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumArguments()
Definition: Block.h:117
BlockArgListType getArguments()
Definition: Block.h:76
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Location.h:104
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a contiguous range of mutable operand ranges, e.g.
Definition: ValueRange.h:195
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments=std::nullopt)
Construct a new mutable range from the given operand, operand start index, and range length.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition: ValueRange.h:120
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists, else returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:189
This class represents an operand of an operation.
Definition: Value.h:255
This is a value defined by a result of an operation.
Definition: Value.h:442
This class adds the property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:82
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:592
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:222
void setOperand(unsigned idx, Value value)
Definition: Operation.h:330
Block * getSuccessor(unsigned index)
Definition: Operation.h:572
unsigned getNumSuccessors()
Definition: Operation.h:570
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition: Operation.h:339
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:640
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:537
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
unsigned getNumOperands()
Definition: Operation.h:325
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:457
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:540
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.h:421
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:103
result_type_range getResultTypes()
Definition: Operation.h:407
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:357
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:203
SuccessorRange getSuccessors()
Definition: Operation.h:567
result_range getResults()
Definition: Operation.h:394
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:383
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements a use iterator for a range of operation results.
Definition: ValueRange.h:334
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:231
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition: ValueRange.h:306
ResultRange(OpResult result)
use_iterator use_end() const
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition: ValueRange.h:287
user_iterator user_end()
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:270
user_iterator user_begin()
UseIterator use_iterator
Definition: ValueRange.h:251
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition: ValueRange.h:375
ValueRange(Arg &&arg)
Definition: ValueRange.h:383
An iterator over the users of an IRObject.
Definition: UseDefLists.h:291
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
U dyn_cast() const
Definition: Value.h:103
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:231
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
This class provides the implementation for an operation result.
Definition: Value.h:353
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
DenseMap< Value, Value > equivalentValues
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags)
Compare two regions (including their subregions) and return whether they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None)
Compare two operations (including their regions) and return whether they are equivalent.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addRegions(MutableArrayRef< std::unique_ptr< Region >> regions)
Take ownership of a set of regions that should be attached to the Operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Region * addRegion()
Create a region that should be attached to the operation.