MLIR  21.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include "llvm/Support/SHA1.h"
20 #include <numeric>
21 #include <optional>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // NamedAttrList
27 //===----------------------------------------------------------------------===//
28 
30  assign(attributes.begin(), attributes.end());
31 }
32 
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34  : NamedAttrList(attributes ? attributes.getValue()
35  : ArrayRef<NamedAttribute>()) {
36  dictionarySorted.setPointerAndInt(attributes, true);
37 }
38 
40  assign(inStart, inEnd);
41 }
42 
44 
45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46  std::optional<NamedAttribute> duplicate =
47  DictionaryAttr::findDuplicate(attrs, isSorted());
48  // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49  // state.
50  if (!isSorted())
51  dictionarySorted.setPointerAndInt(nullptr, true);
52  return duplicate;
53 }
54 
55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56  if (!isSorted()) {
57  DictionaryAttr::sortInPlace(attrs);
58  dictionarySorted.setPointerAndInt(nullptr, true);
59  }
60  if (!dictionarySorted.getPointer())
61  dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62  return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
63 }
64 
65 /// Replaces the attributes with new list of attributes.
67  DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
68  dictionarySorted.setPointerAndInt(nullptr, true);
69 }
70 
72  if (isSorted())
73  dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
74  dictionarySorted.setPointer(nullptr);
75  attrs.push_back(newAttribute);
76 }
77 
78 /// Return the specified attribute if present, null otherwise.
79 Attribute NamedAttrList::get(StringRef name) const {
80  auto it = findAttr(*this, name);
81  return it.second ? it.first->getValue() : Attribute();
82 }
83 Attribute NamedAttrList::get(StringAttr name) const {
84  auto it = findAttr(*this, name);
85  return it.second ? it.first->getValue() : Attribute();
86 }
87 
88 /// Return the specified named attribute if present, std::nullopt otherwise.
89 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
90  auto it = findAttr(*this, name);
91  return it.second ? *it.first : std::optional<NamedAttribute>();
92 }
93 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
94  auto it = findAttr(*this, name);
95  return it.second ? *it.first : std::optional<NamedAttribute>();
96 }
97 
98 /// If the an attribute exists with the specified name, change it to the new
99 /// value. Otherwise, add a new attribute with the specified name/value.
100 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
101  assert(value && "attributes may never be null");
102 
103  // Look for an existing attribute with the given name, and set its value
104  // in-place. Return the previous value of the attribute, if there was one.
105  auto it = findAttr(*this, name);
106  if (it.second) {
107  // Update the existing attribute by swapping out the old value for the new
108  // value. Return the old value.
109  Attribute oldValue = it.first->getValue();
110  if (it.first->getValue() != value) {
111  it.first->setValue(value);
112 
113  // If the attributes have changed, the dictionary is invalidated.
114  dictionarySorted.setPointer(nullptr);
115  }
116  return oldValue;
117  }
118  // Perform a string lookup to insert the new attribute into its sorted
119  // position.
120  if (isSorted())
121  it = findAttr(*this, name.strref());
122  attrs.insert(it.first, {name, value});
123  // Invalidate the dictionary. Return null as there was no previous value.
124  dictionarySorted.setPointer(nullptr);
125  return Attribute();
126 }
127 
128 Attribute NamedAttrList::set(StringRef name, Attribute value) {
129  assert(value && "attributes may never be null");
130  return set(mlir::StringAttr::get(value.getContext(), name), value);
131 }
132 
133 Attribute
134 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
135  // Erasing does not affect the sorted property.
136  Attribute attr = it->getValue();
137  attrs.erase(it);
138  dictionarySorted.setPointer(nullptr);
139  return attr;
140 }
141 
142 Attribute NamedAttrList::erase(StringAttr name) {
143  auto it = findAttr(*this, name);
144  return it.second ? eraseImpl(it.first) : Attribute();
145 }
146 
148  auto it = findAttr(*this, name);
149  return it.second ? eraseImpl(it.first) : Attribute();
150 }
151 
154  assign(rhs.begin(), rhs.end());
155  return *this;
156 }
157 
158 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
159 
160 //===----------------------------------------------------------------------===//
161 // OperationState
162 //===----------------------------------------------------------------------===//
163 
164 OperationState::OperationState(Location location, StringRef name)
165  : location(location), name(name, location->getContext()) {}
166 
168  : location(location), name(name) {}
169 
171  ValueRange operands, TypeRange types,
172  ArrayRef<NamedAttribute> attributes,
173  BlockRange successors,
174  MutableArrayRef<std::unique_ptr<Region>> regions)
175  : location(location), name(name),
176  operands(operands.begin(), operands.end()),
177  types(types.begin(), types.end()),
178  attributes(attributes.begin(), attributes.end()),
179  successors(successors.begin(), successors.end()) {
180  for (std::unique_ptr<Region> &r : regions)
181  this->regions.push_back(std::move(r));
182 }
183 OperationState::OperationState(Location location, StringRef name,
184  ValueRange operands, TypeRange types,
185  ArrayRef<NamedAttribute> attributes,
186  BlockRange successors,
187  MutableArrayRef<std::unique_ptr<Region>> regions)
188  : OperationState(location, OperationName(name, location.getContext()),
189  operands, types, attributes, successors, regions) {}
190 
192  if (properties)
193  propertiesDeleter(properties);
194 }
195 
198  if (LLVM_UNLIKELY(propertiesAttr)) {
199  assert(!properties);
201  }
202  if (properties)
203  propertiesSetter(op->getPropertiesStorage(), properties);
204  return success();
205 }
206 
208  operands.append(newOperands.begin(), newOperands.end());
209 }
210 
212  successors.append(newSuccessors.begin(), newSuccessors.end());
213 }
214 
216  regions.emplace_back(new Region);
217  return regions.back().get();
218 }
219 
220 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
221  regions.push_back(std::move(region));
222 }
223 
225  MutableArrayRef<std::unique_ptr<Region>> regions) {
226  for (std::unique_ptr<Region> &region : regions)
227  addRegion(std::move(region));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // OperandStorage
232 //===----------------------------------------------------------------------===//
233 
235  OpOperand *trailingOperands,
236  ValueRange values)
237  : isStorageDynamic(false), operandStorage(trailingOperands) {
238  numOperands = capacity = values.size();
239  for (unsigned i = 0; i < numOperands; ++i)
240  new (&operandStorage[i]) OpOperand(owner, values[i]);
241 }
242 
244  for (auto &operand : getOperands())
245  operand.~OpOperand();
246 
247  // If the storage is dynamic, deallocate it.
248  if (isStorageDynamic)
249  free(operandStorage);
250 }
251 
252 /// Replace the operands contained in the storage with the ones provided in
253 /// 'values'.
255  MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
256  for (unsigned i = 0, e = values.size(); i != e; ++i)
257  storageOperands[i].set(values[i]);
258 }
259 
260 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
261 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
262 /// than the range pointed to by 'start'+'length'.
263 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
264  unsigned length, ValueRange operands) {
265  // If the new size is the same, we can update inplace.
266  unsigned newSize = operands.size();
267  if (newSize == length) {
268  MutableArrayRef<OpOperand> storageOperands = getOperands();
269  for (unsigned i = 0, e = length; i != e; ++i)
270  storageOperands[start + i].set(operands[i]);
271  return;
272  }
273  // If the new size is greater, remove the extra operands and set the rest
274  // inplace.
275  if (newSize < length) {
276  eraseOperands(start + operands.size(), length - newSize);
277  setOperands(owner, start, newSize, operands);
278  return;
279  }
280  // Otherwise, the new size is greater so we need to grow the storage.
281  auto storageOperands = resize(owner, size() + (newSize - length));
282 
283  // Shift operands to the right to make space for the new operands.
284  unsigned rotateSize = storageOperands.size() - (start + length);
285  auto rbegin = storageOperands.rbegin();
286  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
287 
288  // Update the operands inplace.
289  for (unsigned i = 0, e = operands.size(); i != e; ++i)
290  storageOperands[start + i].set(operands[i]);
291 }
292 
293 /// Erase an operand held by the storage.
294 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
295  MutableArrayRef<OpOperand> operands = getOperands();
296  assert((start + length) <= operands.size());
297  numOperands -= length;
298 
299  // Shift all operands down if the operand to remove is not at the end.
300  if (start != numOperands) {
301  auto *indexIt = std::next(operands.begin(), start);
302  std::rotate(indexIt, std::next(indexIt, length), operands.end());
303  }
304  for (unsigned i = 0; i != length; ++i)
305  operands[numOperands + i].~OpOperand();
306 }
307 
308 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
309  MutableArrayRef<OpOperand> operands = getOperands();
310  assert(eraseIndices.size() == operands.size());
311 
312  // Check that at least one operand is erased.
313  int firstErasedIndice = eraseIndices.find_first();
314  if (firstErasedIndice == -1)
315  return;
316 
317  // Shift all of the removed operands to the end, and destroy them.
318  numOperands = firstErasedIndice;
319  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
320  if (!eraseIndices.test(i))
321  operands[numOperands++] = std::move(operands[i]);
322  for (OpOperand &operand : operands.drop_front(numOperands))
323  operand.~OpOperand();
324 }
325 
326 /// Resize the storage to the given size. Returns the array containing the new
327 /// operands.
328 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
329  unsigned newSize) {
330  // If the number of operands is less than or equal to the current amount, we
331  // can just update in place.
332  MutableArrayRef<OpOperand> origOperands = getOperands();
333  if (newSize <= numOperands) {
334  // If the number of new size is less than the current, remove any extra
335  // operands.
336  for (unsigned i = newSize; i != numOperands; ++i)
337  origOperands[i].~OpOperand();
338  numOperands = newSize;
339  return origOperands.take_front(newSize);
340  }
341 
342  // If the new size is within the original inline capacity, grow inplace.
343  if (newSize <= capacity) {
344  OpOperand *opBegin = origOperands.data();
345  for (unsigned e = newSize; numOperands != e; ++numOperands)
346  new (&opBegin[numOperands]) OpOperand(owner);
347  return MutableArrayRef<OpOperand>(opBegin, newSize);
348  }
349 
350  // Otherwise, we need to allocate a new storage.
351  unsigned newCapacity =
352  std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
353  OpOperand *newOperandStorage =
354  reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
355 
356  // Move the current operands to the new storage.
357  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
358  std::uninitialized_move(origOperands.begin(), origOperands.end(),
359  newOperands.begin());
360 
361  // Destroy the original operands.
362  for (auto &operand : origOperands)
363  operand.~OpOperand();
364 
365  // Initialize any new operands.
366  for (unsigned e = newSize; numOperands != e; ++numOperands)
367  new (&newOperands[numOperands]) OpOperand(owner);
368 
369  // If the current storage is dynamic, free it.
370  if (isStorageDynamic)
371  free(operandStorage);
372 
373  // Update the storage representation to use the new dynamic storage.
374  operandStorage = newOperandStorage;
375  capacity = newCapacity;
376  isStorageDynamic = true;
377  return newOperands;
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // Operation Value-Iterators
382 //===----------------------------------------------------------------------===//
383 
384 //===----------------------------------------------------------------------===//
385 // OperandRange
386 //===----------------------------------------------------------------------===//
387 
389  assert(!empty() && "range must not be empty");
390  return base->getOperandNumber();
391 }
392 
394  return OperandRangeRange(*this, segmentSizes);
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // OperandRangeRange
399 //===----------------------------------------------------------------------===//
400 
402  Attribute operandSegments)
403  : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
404  llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
405 }
406 
408  const OwnerT &owner = getBase();
409  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
410  return OperandRange(owner.first,
411  std::accumulate(sizeData.begin(), sizeData.end(), 0));
412 }
413 
414 OperandRange OperandRangeRange::dereference(const OwnerT &object,
415  ptrdiff_t index) {
416  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
417  uint32_t startIndex =
418  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
419  return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // MutableOperandRange
424 //===----------------------------------------------------------------------===//
425 
426 /// Construct a new mutable range from the given operand, operand start index,
427 /// and range length.
429  Operation *owner, unsigned start, unsigned length,
430  ArrayRef<OperandSegment> operandSegments)
431  : owner(owner), start(start), length(length),
432  operandSegments(operandSegments) {
433  assert((start + length) <= owner->getNumOperands() && "invalid range");
434 }
436  : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
437 
438 /// Construct a new mutable range for the given OpOperand.
440  : MutableOperandRange(opOperand.getOwner(),
441  /*start=*/opOperand.getOperandNumber(),
442  /*length=*/1) {}
443 
444 /// Slice this range into a sub range, with the additional operand segment.
446 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
447  std::optional<OperandSegment> segment) const {
448  assert((subStart + subLen) <= length && "invalid sub-range");
449  MutableOperandRange subSlice(owner, start + subStart, subLen,
450  operandSegments);
451  if (segment)
452  subSlice.operandSegments.push_back(*segment);
453  return subSlice;
454 }
455 
456 /// Append the given values to the range.
458  if (values.empty())
459  return;
460  owner->insertOperands(start + length, values);
461  updateLength(length + values.size());
462 }
463 
464 /// Assign this range to the given values.
466  owner->setOperands(start, length, values);
467  if (length != values.size())
468  updateLength(/*newLength=*/values.size());
469 }
470 
471 /// Assign the range to the given value.
473  if (length == 1) {
474  owner->setOperand(start, value);
475  } else {
476  owner->setOperands(start, length, value);
477  updateLength(/*newLength=*/1);
478  }
479 }
480 
481 /// Erase the operands within the given sub-range.
482 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
483  assert((subStart + subLen) <= length && "invalid sub-range");
484  if (length == 0)
485  return;
486  owner->eraseOperands(start + subStart, subLen);
487  updateLength(length - subLen);
488 }
489 
490 /// Clear this range and erase all of the operands.
492  if (length != 0) {
493  owner->eraseOperands(start, length);
494  updateLength(/*newLength=*/0);
495  }
496 }
497 
498 /// Explicit conversion to an OperandRange.
500  return owner->getOperands().slice(start, length);
501 }
502 
503 /// Allow implicit conversion to an OperandRange.
504 MutableOperandRange::operator OperandRange() const {
505  return getAsOperandRange();
506 }
507 
508 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
509  return owner->getOpOperands().slice(start, length);
510 }
511 
514  return MutableOperandRangeRange(*this, segmentSizes);
515 }
516 
517 /// Update the length of this range to the one provided.
518 void MutableOperandRange::updateLength(unsigned newLength) {
519  int32_t diff = int32_t(newLength) - int32_t(length);
520  length = newLength;
521 
522  // Update any of the provided segment attributes.
523  for (OperandSegment &segment : operandSegments) {
524  auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
525  SmallVector<int32_t, 8> segments(attr.asArrayRef());
526  segments[segment.first] += diff;
527  segment.second.setValue(
528  DenseI32ArrayAttr::get(attr.getContext(), segments));
529  owner->setAttr(segment.second.getName(), segment.second.getValue());
530  }
531 }
532 
534  assert(index < length && "index is out of bounds");
535  return owner->getOpOperand(start + index);
536 }
537 
539  return owner->getOpOperands().slice(start, length).begin();
540 }
541 
543  return owner->getOpOperands().slice(start, length).end();
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // MutableOperandRangeRange
548 //===----------------------------------------------------------------------===//
549 
551  const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
553  OwnerT(operands, operandSegmentAttr), 0,
554  llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
555 }
556 
558  return getBase().first;
559 }
560 
561 MutableOperandRangeRange::operator OperandRangeRange() const {
562  return OperandRangeRange(getBase().first, getBase().second.getValue());
563 }
564 
565 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
566  ptrdiff_t index) {
567  ArrayRef<int32_t> sizeData =
568  llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
569  uint32_t startIndex =
570  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
571  return object.first.slice(
572  startIndex, *(sizeData.begin() + index),
573  MutableOperandRange::OperandSegment(index, object.second));
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // ResultRange
578 //===----------------------------------------------------------------------===//
579 
581  : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
582  1) {}
583 
585  return {use_begin(), use_end()};
586 }
588  return use_iterator(*this);
589 }
591  return use_iterator(*this, /*end=*/true);
592 }
594  return {user_begin(), user_end()};
595 }
597  return user_iterator(use_begin());
598 }
600  return user_iterator(use_end());
601 }
602 
604  : it(end ? results.end() : results.begin()), endIt(results.end()) {
605  // Only initialize current use if there are results/can be uses.
606  if (it != endIt)
607  skipOverResultsWithNoUsers();
608 }
609 
611  // We increment over uses, if we reach the last use then move to next
612  // result.
613  if (use != (*it).use_end())
614  ++use;
615  if (use == (*it).use_end()) {
616  ++it;
617  skipOverResultsWithNoUsers();
618  }
619  return *this;
620 }
621 
622 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
623  while (it != endIt && (*it).use_empty())
624  ++it;
625 
626  // If we are at the last result, then set use to first use of
627  // first result (sentinel value used for end).
628  if (it == endIt)
629  use = {};
630  else
631  use = (*it).use_begin();
632 }
633 
636 }
637 
639  Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
640  replaceUsesWithIf(op->getResults(), shouldReplace);
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // ValueRange
645 //===----------------------------------------------------------------------===//
646 
648  : ValueRange(values.data(), values.size()) {}
650  : ValueRange(values.begin().getBase(), values.size()) {}
652  : ValueRange(values.getBase(), values.size()) {}
653 
654 /// See `llvm::detail::indexed_accessor_range_base` for details.
655 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
656  ptrdiff_t index) {
657  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
658  return {value + index};
659  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
660  return {operand + index};
661  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
662 }
663 /// See `llvm::detail::indexed_accessor_range_base` for details.
664 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
665  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
666  return value[index];
667  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
668  return operand[index].get();
669  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // Operation Equivalency
674 //===----------------------------------------------------------------------===//
675 
677  Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
678  function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
679  // Hash operations based upon their:
680  // - Operation Name
681  // - Attributes
682  // - Result Types
683  llvm::hash_code hash =
684  llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(),
685  op->getResultTypes(), op->hashProperties());
686 
687  // - Location if required
688  if (!(flags & Flags::IgnoreLocations))
689  hash = llvm::hash_combine(hash, op->getLoc());
690 
691  // - Operands
693  op->getNumOperands() > 0) {
694  size_t operandHash = hashOperands(op->getOperand(0));
695  for (auto operand : op->getOperands().drop_front())
696  operandHash += hashOperands(operand);
697  hash = llvm::hash_combine(hash, operandHash);
698  } else {
699  for (Value operand : op->getOperands())
700  hash = llvm::hash_combine(hash, hashOperands(operand));
701  }
702 
703  // - Results
704  for (Value result : op->getResults())
705  hash = llvm::hash_combine(hash, hashResults(result));
706  return hash;
707 }
708 
710  Region *lhs, Region *rhs,
711  function_ref<LogicalResult(Value, Value)> checkEquivalent,
712  function_ref<void(Value, Value)> markEquivalent,
714  function_ref<LogicalResult(ValueRange, ValueRange)>
715  checkCommutativeEquivalent) {
716  DenseMap<Block *, Block *> blocksMap;
717  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
718  // Check block arguments.
719  if (lBlock.getNumArguments() != rBlock.getNumArguments())
720  return false;
721 
722  // Map the two blocks.
723  auto insertion = blocksMap.insert({&lBlock, &rBlock});
724  if (insertion.first->getSecond() != &rBlock)
725  return false;
726 
727  for (auto argPair :
728  llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
729  Value curArg = std::get<0>(argPair);
730  Value otherArg = std::get<1>(argPair);
731  if (curArg.getType() != otherArg.getType())
732  return false;
733  if (!(flags & OperationEquivalence::IgnoreLocations) &&
734  curArg.getLoc() != otherArg.getLoc())
735  return false;
736  // Corresponding bbArgs are equivalent.
737  if (markEquivalent)
738  markEquivalent(curArg, otherArg);
739  }
740 
741  auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
742  // Check for op equality (recursively).
743  if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
744  markEquivalent, flags,
745  checkCommutativeEquivalent))
746  return false;
747  // Check successor mapping.
748  for (auto successorsPair :
749  llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
750  Block *curSuccessor = std::get<0>(successorsPair);
751  Block *otherSuccessor = std::get<1>(successorsPair);
752  auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
753  if (insertion.first->getSecond() != otherSuccessor)
754  return false;
755  }
756  return true;
757  };
758  return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
759  };
760  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
761 }
762 
763 // Value equivalence cache to be used with `isRegionEquivalentTo` and
764 // `isEquivalentTo`.
767  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
768  return success(lhsValue == rhsValue ||
769  equivalentValues.lookup(lhsValue) == rhsValue);
770  }
771  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
772  ValueRange rhsRange) {
773  // Handle simple case where sizes mismatch.
774  if (lhsRange.size() != rhsRange.size())
775  return failure();
776 
777  // Handle where operands in order are equivalent.
778  auto lhsIt = lhsRange.begin();
779  auto rhsIt = rhsRange.begin();
780  for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
781  if (failed(checkEquivalent(*lhsIt, *rhsIt)))
782  break;
783  }
784  if (lhsIt == lhsRange.end())
785  return success();
786 
787  // Handle another simple case where operands are just a permutation.
788  // Note: This is not sufficient, this handles simple cases relatively
789  // cheaply.
790  auto sortValues = [](ValueRange values) {
791  SmallVector<Value> sortedValues = llvm::to_vector(values);
792  llvm::sort(sortedValues, [](Value a, Value b) {
793  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
794  });
795  return sortedValues;
796  };
797  auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
798  auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
799  return success(lhsSorted == rhsSorted);
800  }
801  void markEquivalent(Value lhsResult, Value rhsResult) {
802  auto insertion = equivalentValues.insert({lhsResult, rhsResult});
803  // Make sure that the value was not already marked equivalent to some other
804  // value.
805  (void)insertion;
806  assert(insertion.first->second == rhsResult &&
807  "inconsistent OperationEquivalence state");
808  }
809 };
810 
811 /*static*/ bool
814  ValueEquivalenceCache cache;
815  return isRegionEquivalentTo(
816  lhs, rhs,
817  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
818  return cache.checkEquivalent(lhsValue, rhsValue);
819  },
820  [&](Value lhsResult, Value rhsResult) {
821  cache.markEquivalent(lhsResult, rhsResult);
822  },
823  flags,
824  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
825  return cache.checkCommutativeEquivalent(lhs, rhs);
826  });
827 }
828 
830  Operation *lhs, Operation *rhs,
831  function_ref<LogicalResult(Value, Value)> checkEquivalent,
832  function_ref<void(Value, Value)> markEquivalent, Flags flags,
833  function_ref<LogicalResult(ValueRange, ValueRange)>
834  checkCommutativeEquivalent) {
835  if (lhs == rhs)
836  return true;
837 
838  // 1. Compare the operation properties.
839  if (lhs->getName() != rhs->getName() ||
840  lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() ||
841  lhs->getNumRegions() != rhs->getNumRegions() ||
842  lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
843  lhs->getNumOperands() != rhs->getNumOperands() ||
844  lhs->getNumResults() != rhs->getNumResults() ||
846  rhs->getPropertiesStorage()))
847  return false;
848  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
849  return false;
850 
851  // 2. Compare operands.
852  if (checkCommutativeEquivalent &&
854  auto lhsRange = lhs->getOperands();
855  auto rhsRange = rhs->getOperands();
856  if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
857  return false;
858  } else {
859  // Check pair wise for equivalence.
860  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
861  Value curArg = std::get<0>(operandPair);
862  Value otherArg = std::get<1>(operandPair);
863  if (curArg == otherArg)
864  continue;
865  if (curArg.getType() != otherArg.getType())
866  return false;
867  if (failed(checkEquivalent(curArg, otherArg)))
868  return false;
869  }
870  }
871 
872  // 3. Compare result types and mark results as equivalent.
873  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
874  Value curArg = std::get<0>(resultPair);
875  Value otherArg = std::get<1>(resultPair);
876  if (curArg.getType() != otherArg.getType())
877  return false;
878  if (markEquivalent)
879  markEquivalent(curArg, otherArg);
880  }
881 
882  // 4. Compare regions.
883  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
884  if (!isRegionEquivalentTo(&std::get<0>(regionPair),
885  &std::get<1>(regionPair), checkEquivalent,
886  markEquivalent, flags))
887  return false;
888 
889  return true;
890 }
891 
893  Operation *rhs,
894  Flags flags) {
895  ValueEquivalenceCache cache;
897  lhs, rhs,
898  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
899  return cache.checkEquivalent(lhsValue, rhsValue);
900  },
901  [&](Value lhsResult, Value rhsResult) {
902  cache.markEquivalent(lhsResult, rhsResult);
903  },
904  flags,
905  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
906  return cache.checkCommutativeEquivalent(lhs, rhs);
907  });
908 }
909 
910 //===----------------------------------------------------------------------===//
911 // OperationFingerPrint
912 //===----------------------------------------------------------------------===//
913 
914 template <typename T>
915 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
916  hasher.update(
917  ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
918 }
919 
921  bool includeNested) {
922  llvm::SHA1 hasher;
923 
924  // Helper function that hashes an operation based on its mutable bits:
925  auto addOperationToHash = [&](Operation *op) {
926  // - Operation pointer
927  addDataToHash(hasher, op);
928  // - Parent operation pointer (to take into account the nesting structure)
929  if (op != topOp)
930  addDataToHash(hasher, op->getParentOp());
931  // - Attributes
932  addDataToHash(hasher, op->getRawDictionaryAttrs());
933  // - Properties
934  addDataToHash(hasher, op->hashProperties());
935  // - Blocks in Regions
936  for (Region &region : op->getRegions()) {
937  for (Block &block : region) {
938  addDataToHash(hasher, &block);
939  for (BlockArgument arg : block.getArguments())
940  addDataToHash(hasher, arg);
941  }
942  }
943  // - Location
944  addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
945  // - Operands
946  for (Value operand : op->getOperands())
947  addDataToHash(hasher, operand);
948  // - Successors
949  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
950  addDataToHash(hasher, op->getSuccessor(i));
951  // - Result types
952  for (Type t : op->getResultTypes())
953  addDataToHash(hasher, t);
954  };
955 
956  if (includeNested)
957  topOp->walk(addOperationToHash);
958  else
959  addOperationToHash(topOp);
960 
961  hash = hasher.result();
962 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:295
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumArguments()
Definition: Block.h:128
BlockArgListType getArguments()
Definition: Block.h:87
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Location.h:93
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a contiguous range of mutable operand ranges, e.g.
Definition: ValueRange.h:210
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments=std::nullopt)
Construct a new mutable range from the given operand, operand start index, and range length.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition: ValueRange.h:123
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if one exists; otherwise returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This class represents an operand of an operation.
Definition: Value.h:243
This is a value defined by a result of an operation.
Definition: Value.h:433
This class adds property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:84
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp, bool includeNested=true)
bool compareOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:388
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
Block * getSuccessor(unsigned index)
Definition: Operation.h:709
unsigned getNumSuccessors()
Definition: Operation.h:707
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition: Operation.h:360
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:509
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
Definition: Operation.cpp:355
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
SuccessorRange getSuccessors()
Definition: Operation.h:704
result_range getResults()
Definition: Operation.h:415
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
Definition: Operation.cpp:370
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements a use iterator for a range of operation results.
Definition: ValueRange.h:350
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition: ValueRange.h:322
ResultRange(OpResult result)
use_iterator use_end() const
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition: ValueRange.h:303
user_iterator user_end()
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:286
user_iterator user_begin()
UseIterator use_iterator
Definition: ValueRange.h:267
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition: ValueRange.h:392
ValueRange(Arg &&arg LLVM_LIFETIME_BOUND)
Definition: ValueRange.h:400
An iterator over the users of an IRObject.
Definition: UseDefLists.h:344
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:219
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, ValueRange rhsRange)
DenseMap< Value, Value > equivalentValues
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addRegions(MutableArrayRef< std::unique_ptr< Region >> regions)
Take ownership of a set of regions that should be attached to the Operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
Adds a successor to the operation state. `successor` must not be null.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const