MLIR  21.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/Support/SHA1.h"
19 #include <numeric>
20 #include <optional>
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // NamedAttrList
26 //===----------------------------------------------------------------------===//
27 
29  assign(attributes.begin(), attributes.end());
30 }
31 
32 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
33  : NamedAttrList(attributes ? attributes.getValue()
34  : ArrayRef<NamedAttribute>()) {
35  dictionarySorted.setPointerAndInt(attributes, true);
36 }
37 
39  assign(inStart, inEnd);
40 }
41 
43 
44 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
45  std::optional<NamedAttribute> duplicate =
46  DictionaryAttr::findDuplicate(attrs, isSorted());
47  // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
48  // state.
49  if (!isSorted())
50  dictionarySorted.setPointerAndInt(nullptr, true);
51  return duplicate;
52 }
53 
54 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
55  if (!isSorted()) {
56  DictionaryAttr::sortInPlace(attrs);
57  dictionarySorted.setPointerAndInt(nullptr, true);
58  }
59  if (!dictionarySorted.getPointer())
60  dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
61  return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
62 }
63 
64 /// Replaces the attributes with new list of attributes.
66  DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
67  dictionarySorted.setPointerAndInt(nullptr, true);
68 }
69 
71  if (isSorted())
72  dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
73  dictionarySorted.setPointer(nullptr);
74  attrs.push_back(newAttribute);
75 }
76 
77 /// Return the specified attribute if present, null otherwise.
78 Attribute NamedAttrList::get(StringRef name) const {
79  auto it = findAttr(*this, name);
80  return it.second ? it.first->getValue() : Attribute();
81 }
82 Attribute NamedAttrList::get(StringAttr name) const {
83  auto it = findAttr(*this, name);
84  return it.second ? it.first->getValue() : Attribute();
85 }
86 
87 /// Return the specified named attribute if present, std::nullopt otherwise.
88 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
89  auto it = findAttr(*this, name);
90  return it.second ? *it.first : std::optional<NamedAttribute>();
91 }
92 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
93  auto it = findAttr(*this, name);
94  return it.second ? *it.first : std::optional<NamedAttribute>();
95 }
96 
97 /// If the an attribute exists with the specified name, change it to the new
98 /// value. Otherwise, add a new attribute with the specified name/value.
99 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
100  assert(value && "attributes may never be null");
101 
102  // Look for an existing attribute with the given name, and set its value
103  // in-place. Return the previous value of the attribute, if there was one.
104  auto it = findAttr(*this, name);
105  if (it.second) {
106  // Update the existing attribute by swapping out the old value for the new
107  // value. Return the old value.
108  Attribute oldValue = it.first->getValue();
109  if (it.first->getValue() != value) {
110  it.first->setValue(value);
111 
112  // If the attributes have changed, the dictionary is invalidated.
113  dictionarySorted.setPointer(nullptr);
114  }
115  return oldValue;
116  }
117  // Perform a string lookup to insert the new attribute into its sorted
118  // position.
119  if (isSorted())
120  it = findAttr(*this, name.strref());
121  attrs.insert(it.first, {name, value});
122  // Invalidate the dictionary. Return null as there was no previous value.
123  dictionarySorted.setPointer(nullptr);
124  return Attribute();
125 }
126 
127 Attribute NamedAttrList::set(StringRef name, Attribute value) {
128  assert(value && "attributes may never be null");
129  return set(mlir::StringAttr::get(value.getContext(), name), value);
130 }
131 
132 Attribute
133 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
134  // Erasing does not affect the sorted property.
135  Attribute attr = it->getValue();
136  attrs.erase(it);
137  dictionarySorted.setPointer(nullptr);
138  return attr;
139 }
140 
141 Attribute NamedAttrList::erase(StringAttr name) {
142  auto it = findAttr(*this, name);
143  return it.second ? eraseImpl(it.first) : Attribute();
144 }
145 
147  auto it = findAttr(*this, name);
148  return it.second ? eraseImpl(it.first) : Attribute();
149 }
150 
153  assign(rhs.begin(), rhs.end());
154  return *this;
155 }
156 
157 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
158 
159 //===----------------------------------------------------------------------===//
160 // OperationState
161 //===----------------------------------------------------------------------===//
162 
163 OperationState::OperationState(Location location, StringRef name)
164  : location(location), name(name, location->getContext()) {}
165 
167  : location(location), name(name) {}
168 
170  ValueRange operands, TypeRange types,
171  ArrayRef<NamedAttribute> attributes,
172  BlockRange successors,
173  MutableArrayRef<std::unique_ptr<Region>> regions)
174  : location(location), name(name),
175  operands(operands.begin(), operands.end()),
176  types(types.begin(), types.end()),
177  attributes(attributes.begin(), attributes.end()),
178  successors(successors.begin(), successors.end()) {
179  for (std::unique_ptr<Region> &r : regions)
180  this->regions.push_back(std::move(r));
181 }
182 OperationState::OperationState(Location location, StringRef name,
183  ValueRange operands, TypeRange types,
184  ArrayRef<NamedAttribute> attributes,
185  BlockRange successors,
186  MutableArrayRef<std::unique_ptr<Region>> regions)
187  : OperationState(location, OperationName(name, location.getContext()),
188  operands, types, attributes, successors, regions) {}
189 
191  if (properties)
192  propertiesDeleter(properties);
193 }
194 
197  if (LLVM_UNLIKELY(propertiesAttr)) {
198  assert(!properties);
200  }
201  if (properties)
202  propertiesSetter(op->getPropertiesStorage(), properties);
203  return success();
204 }
205 
207  operands.append(newOperands.begin(), newOperands.end());
208 }
209 
211  successors.append(newSuccessors.begin(), newSuccessors.end());
212 }
213 
215  regions.emplace_back(new Region);
216  return regions.back().get();
217 }
218 
219 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
220  regions.push_back(std::move(region));
221 }
222 
224  MutableArrayRef<std::unique_ptr<Region>> regions) {
225  for (std::unique_ptr<Region> &region : regions)
226  addRegion(std::move(region));
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // OperandStorage
231 //===----------------------------------------------------------------------===//
232 
234  OpOperand *trailingOperands,
235  ValueRange values)
236  : isStorageDynamic(false), operandStorage(trailingOperands) {
237  numOperands = capacity = values.size();
238  for (unsigned i = 0; i < numOperands; ++i)
239  new (&operandStorage[i]) OpOperand(owner, values[i]);
240 }
241 
243  for (auto &operand : getOperands())
244  operand.~OpOperand();
245 
246  // If the storage is dynamic, deallocate it.
247  if (isStorageDynamic)
248  free(operandStorage);
249 }
250 
251 /// Replace the operands contained in the storage with the ones provided in
252 /// 'values'.
254  MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
255  for (unsigned i = 0, e = values.size(); i != e; ++i)
256  storageOperands[i].set(values[i]);
257 }
258 
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
///
/// Three cases are handled:
///  - same size: update the existing operands in place;
///  - shrinking: erase the surplus operands, then recurse into the same-size
///    case;
///  - growing: resize the storage, rotate the trailing operands to the right
///    to open a gap, then update in place.
void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
                                         unsigned length, ValueRange operands) {
  // If the new size is the same, we can update inplace.
  unsigned newSize = operands.size();
  if (newSize == length) {
    MutableArrayRef<OpOperand> storageOperands = getOperands();
    for (unsigned i = 0, e = length; i != e; ++i)
      storageOperands[start + i].set(operands[i]);
    return;
  }
  // If the new size is smaller, remove the extra operands and set the rest
  // inplace.
  if (newSize < length) {
    // Erase the last (length - newSize) operands of the range; the recursive
    // call below then takes the same-size path.
    eraseOperands(start + operands.size(), length - newSize);
    setOperands(owner, start, newSize, operands);
    return;
  }
  // Otherwise, the new size is greater so we need to grow the storage.
  auto storageOperands = resize(owner, size() + (newSize - length));

  // Shift operands to the right to make space for the new operands.
  // `resize` appends the extra (newSize - length) operands at the very end.
  // Rotating the reversed suffix [start + length, end) by that amount moves
  // the original trailing operands back to the end and leaves the freshly
  // added operands immediately after the replaced range.
  unsigned rotateSize = storageOperands.size() - (start + length);
  auto rbegin = storageOperands.rbegin();
  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);

  // Update the operands inplace.
  for (unsigned i = 0, e = operands.size(); i != e; ++i)
    storageOperands[start + i].set(operands[i]);
}
291 
/// Erase an operand held by the storage.
///
/// Erases the `length` operands starting at index `start`. The erased operands
/// are rotated to the tail of the storage and destroyed there. `capacity` is
/// left unchanged.
void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
  MutableArrayRef<OpOperand> operands = getOperands();
  assert((start + length) <= operands.size());
  numOperands -= length;

  // Shift all operands down if the operand to remove is not at the end.
  // After the decrement, `start == numOperands` means the erased range was
  // already the tail, so no rotation is needed.
  if (start != numOperands) {
    auto *indexIt = std::next(operands.begin(), start);
    std::rotate(indexIt, std::next(indexIt, length), operands.end());
  }
  // Destroy the erased operands, which now occupy the last `length` slots of
  // the original (pre-decrement) range.
  for (unsigned i = 0; i != length; ++i)
    operands[numOperands + i].~OpOperand();
}
306 
307 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
308  MutableArrayRef<OpOperand> operands = getOperands();
309  assert(eraseIndices.size() == operands.size());
310 
311  // Check that at least one operand is erased.
312  int firstErasedIndice = eraseIndices.find_first();
313  if (firstErasedIndice == -1)
314  return;
315 
316  // Shift all of the removed operands to the end, and destroy them.
317  numOperands = firstErasedIndice;
318  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
319  if (!eraseIndices.test(i))
320  operands[numOperands++] = std::move(operands[i]);
321  for (OpOperand &operand : operands.drop_front(numOperands))
322  operand.~OpOperand();
323 }
324 
/// Resize the storage to the given size. Returns the array containing the new
/// operands.
///
/// Shrinking, or growing within the current capacity, is done in place.
/// Growing beyond the capacity moves every operand into a newly malloc'd
/// buffer. Any newly added operands are constructed as `OpOperand(owner)`,
/// with no value argument.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
                                                          unsigned newSize) {
  // If the number of operands is less than or equal to the current amount, we
  // can just update in place.
  MutableArrayRef<OpOperand> origOperands = getOperands();
  if (newSize <= numOperands) {
    // If the new size is less than the current size, destroy the extra
    // operands.
    for (unsigned i = newSize; i != numOperands; ++i)
      origOperands[i].~OpOperand();
    numOperands = newSize;
    return origOperands.take_front(newSize);
  }

  // If the new size is within the current capacity (inline or previously
  // allocated), grow in place by constructing the additional operands.
  if (newSize <= capacity) {
    OpOperand *opBegin = origOperands.data();
    for (unsigned e = newSize; numOperands != e; ++numOperands)
      new (&opBegin[numOperands]) OpOperand(owner);
    return MutableArrayRef<OpOperand>(opBegin, newSize);
  }

  // Otherwise, we need to allocate a new storage.
  // The new capacity is llvm::NextPowerOf2(capacity + 2), but never less than
  // the requested size.
  unsigned newCapacity =
      std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
  // NOTE(review): the result of malloc is not checked for null here.
  OpOperand *newOperandStorage =
      reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));

  // Move the current operands to the new storage.
  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
  std::uninitialized_move(origOperands.begin(), origOperands.end(),
                          newOperands.begin());

  // Destroy the original operands.
  for (auto &operand : origOperands)
    operand.~OpOperand();

  // Initialize any new operands.
  for (unsigned e = newSize; numOperands != e; ++numOperands)
    new (&newOperands[numOperands]) OpOperand(owner);

  // If the current storage is dynamic, free it.
  // Non-dynamic storage is the trailing buffer handed to the constructor and
  // was not malloc'd here, so it must not be freed.
  if (isStorageDynamic)
    free(operandStorage);

  // Update the storage representation to use the new dynamic storage.
  operandStorage = newOperandStorage;
  capacity = newCapacity;
  isStorageDynamic = true;
  return newOperands;
}
378 
379 //===----------------------------------------------------------------------===//
380 // Operation Value-Iterators
381 //===----------------------------------------------------------------------===//
382 
383 //===----------------------------------------------------------------------===//
384 // OperandRange
385 //===----------------------------------------------------------------------===//
386 
388  assert(!empty() && "range must not be empty");
389  return base->getOperandNumber();
390 }
391 
393  return OperandRangeRange(*this, segmentSizes);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // OperandRangeRange
398 //===----------------------------------------------------------------------===//
399 
401  Attribute operandSegments)
402  : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
403  llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
404 }
405 
407  const OwnerT &owner = getBase();
408  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
409  return OperandRange(owner.first,
410  std::accumulate(sizeData.begin(), sizeData.end(), 0));
411 }
412 
413 OperandRange OperandRangeRange::dereference(const OwnerT &object,
414  ptrdiff_t index) {
415  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
416  uint32_t startIndex =
417  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
418  return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // MutableOperandRange
423 //===----------------------------------------------------------------------===//
424 
425 /// Construct a new mutable range from the given operand, operand start index,
426 /// and range length.
428  Operation *owner, unsigned start, unsigned length,
429  ArrayRef<OperandSegment> operandSegments)
430  : owner(owner), start(start), length(length),
431  operandSegments(operandSegments) {
432  assert((start + length) <= owner->getNumOperands() && "invalid range");
433 }
435  : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
436 
437 /// Construct a new mutable range for the given OpOperand.
439  : MutableOperandRange(opOperand.getOwner(),
440  /*start=*/opOperand.getOperandNumber(),
441  /*length=*/1) {}
442 
443 /// Slice this range into a sub range, with the additional operand segment.
445 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
446  std::optional<OperandSegment> segment) const {
447  assert((subStart + subLen) <= length && "invalid sub-range");
448  MutableOperandRange subSlice(owner, start + subStart, subLen,
449  operandSegments);
450  if (segment)
451  subSlice.operandSegments.push_back(*segment);
452  return subSlice;
453 }
454 
455 /// Append the given values to the range.
457  if (values.empty())
458  return;
459  owner->insertOperands(start + length, values);
460  updateLength(length + values.size());
461 }
462 
463 /// Assign this range to the given values.
465  owner->setOperands(start, length, values);
466  if (length != values.size())
467  updateLength(/*newLength=*/values.size());
468 }
469 
470 /// Assign the range to the given value.
472  if (length == 1) {
473  owner->setOperand(start, value);
474  } else {
475  owner->setOperands(start, length, value);
476  updateLength(/*newLength=*/1);
477  }
478 }
479 
480 /// Erase the operands within the given sub-range.
481 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
482  assert((subStart + subLen) <= length && "invalid sub-range");
483  if (length == 0)
484  return;
485  owner->eraseOperands(start + subStart, subLen);
486  updateLength(length - subLen);
487 }
488 
489 /// Clear this range and erase all of the operands.
491  if (length != 0) {
492  owner->eraseOperands(start, length);
493  updateLength(/*newLength=*/0);
494  }
495 }
496 
497 /// Explicit conversion to an OperandRange.
499  return owner->getOperands().slice(start, length);
500 }
501 
502 /// Allow implicit conversion to an OperandRange.
503 MutableOperandRange::operator OperandRange() const {
504  return getAsOperandRange();
505 }
506 
507 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
508  return owner->getOpOperands().slice(start, length);
509 }
510 
513  return MutableOperandRangeRange(*this, segmentSizes);
514 }
515 
516 /// Update the length of this range to the one provided.
517 void MutableOperandRange::updateLength(unsigned newLength) {
518  int32_t diff = int32_t(newLength) - int32_t(length);
519  length = newLength;
520 
521  // Update any of the provided segment attributes.
522  for (OperandSegment &segment : operandSegments) {
523  auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
524  SmallVector<int32_t, 8> segments(attr.asArrayRef());
525  segments[segment.first] += diff;
526  segment.second.setValue(
527  DenseI32ArrayAttr::get(attr.getContext(), segments));
528  owner->setAttr(segment.second.getName(), segment.second.getValue());
529  }
530 }
531 
533  assert(index < length && "index is out of bounds");
534  return owner->getOpOperand(start + index);
535 }
536 
538  return owner->getOpOperands().slice(start, length).begin();
539 }
540 
542  return owner->getOpOperands().slice(start, length).end();
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // MutableOperandRangeRange
547 //===----------------------------------------------------------------------===//
548 
550  const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
552  OwnerT(operands, operandSegmentAttr), 0,
553  llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
554 }
555 
557  return getBase().first;
558 }
559 
560 MutableOperandRangeRange::operator OperandRangeRange() const {
561  return OperandRangeRange(getBase().first, getBase().second.getValue());
562 }
563 
564 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
565  ptrdiff_t index) {
566  ArrayRef<int32_t> sizeData =
567  llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
568  uint32_t startIndex =
569  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
570  return object.first.slice(
571  startIndex, *(sizeData.begin() + index),
572  MutableOperandRange::OperandSegment(index, object.second));
573 }
574 
575 //===----------------------------------------------------------------------===//
576 // ResultRange
577 //===----------------------------------------------------------------------===//
578 
580  : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
581  1) {}
582 
584  return {use_begin(), use_end()};
585 }
587  return use_iterator(*this);
588 }
590  return use_iterator(*this, /*end=*/true);
591 }
593  return {user_begin(), user_end()};
594 }
596  return user_iterator(use_begin());
597 }
599  return user_iterator(use_end());
600 }
601 
603  : it(end ? results.end() : results.begin()), endIt(results.end()) {
604  // Only initialize current use if there are results/can be uses.
605  if (it != endIt)
606  skipOverResultsWithNoUsers();
607 }
608 
610  // We increment over uses, if we reach the last use then move to next
611  // result.
612  if (use != (*it).use_end())
613  ++use;
614  if (use == (*it).use_end()) {
615  ++it;
616  skipOverResultsWithNoUsers();
617  }
618  return *this;
619 }
620 
621 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
622  while (it != endIt && (*it).use_empty())
623  ++it;
624 
625  // If we are at the last result, then set use to first use of
626  // first result (sentinel value used for end).
627  if (it == endIt)
628  use = {};
629  else
630  use = (*it).use_begin();
631 }
632 
635 }
636 
638  Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
639  replaceUsesWithIf(op->getResults(), shouldReplace);
640 }
641 
642 //===----------------------------------------------------------------------===//
643 // ValueRange
644 //===----------------------------------------------------------------------===//
645 
647  : ValueRange(values.data(), values.size()) {}
649  : ValueRange(values.begin().getBase(), values.size()) {}
651  : ValueRange(values.getBase(), values.size()) {}
652 
653 /// See `llvm::detail::indexed_accessor_range_base` for details.
654 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
655  ptrdiff_t index) {
656  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
657  return {value + index};
658  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
659  return {operand + index};
660  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
661 }
662 /// See `llvm::detail::indexed_accessor_range_base` for details.
663 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
664  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
665  return value[index];
666  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
667  return operand[index].get();
668  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // Operation Equivalency
673 //===----------------------------------------------------------------------===//
674 
676  Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
677  function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
678  // Hash operations based upon their:
679  // - Operation Name
680  // - Attributes
681  // - Result Types
682  DictionaryAttr dictAttrs;
683  if (!(flags & Flags::IgnoreDiscardableAttrs))
684  dictAttrs = op->getRawDictionaryAttrs();
685  llvm::hash_code hash =
686  llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes());
687  if (!(flags & Flags::IgnoreProperties))
688  hash = llvm::hash_combine(hash, op->hashProperties());
689 
690  // - Location if required
691  if (!(flags & Flags::IgnoreLocations))
692  hash = llvm::hash_combine(hash, op->getLoc());
693 
694  // - Operands
696  op->getNumOperands() > 0) {
697  size_t operandHash = hashOperands(op->getOperand(0));
698  for (auto operand : op->getOperands().drop_front())
699  operandHash += hashOperands(operand);
700  hash = llvm::hash_combine(hash, operandHash);
701  } else {
702  for (Value operand : op->getOperands())
703  hash = llvm::hash_combine(hash, hashOperands(operand));
704  }
705 
706  // - Results
707  for (Value result : op->getResults())
708  hash = llvm::hash_combine(hash, hashResults(result));
709  return hash;
710 }
711 
713  Region *lhs, Region *rhs,
714  function_ref<LogicalResult(Value, Value)> checkEquivalent,
715  function_ref<void(Value, Value)> markEquivalent,
717  function_ref<LogicalResult(ValueRange, ValueRange)>
718  checkCommutativeEquivalent) {
719  DenseMap<Block *, Block *> blocksMap;
720  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
721  // Check block arguments.
722  if (lBlock.getNumArguments() != rBlock.getNumArguments())
723  return false;
724 
725  // Map the two blocks.
726  auto insertion = blocksMap.insert({&lBlock, &rBlock});
727  if (insertion.first->getSecond() != &rBlock)
728  return false;
729 
730  for (auto argPair :
731  llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
732  Value curArg = std::get<0>(argPair);
733  Value otherArg = std::get<1>(argPair);
734  if (curArg.getType() != otherArg.getType())
735  return false;
736  if (!(flags & OperationEquivalence::IgnoreLocations) &&
737  curArg.getLoc() != otherArg.getLoc())
738  return false;
739  // Corresponding bbArgs are equivalent.
740  if (markEquivalent)
741  markEquivalent(curArg, otherArg);
742  }
743 
744  auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
745  // Check for op equality (recursively).
746  if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
747  markEquivalent, flags,
748  checkCommutativeEquivalent))
749  return false;
750  // Check successor mapping.
751  for (auto successorsPair :
752  llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
753  Block *curSuccessor = std::get<0>(successorsPair);
754  Block *otherSuccessor = std::get<1>(successorsPair);
755  auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
756  if (insertion.first->getSecond() != otherSuccessor)
757  return false;
758  }
759  return true;
760  };
761  return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
762  };
763  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
764 }
765 
766 // Value equivalence cache to be used with `isRegionEquivalentTo` and
767 // `isEquivalentTo`.
770  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
771  return success(lhsValue == rhsValue ||
772  equivalentValues.lookup(lhsValue) == rhsValue);
773  }
774  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
775  ValueRange rhsRange) {
776  // Handle simple case where sizes mismatch.
777  if (lhsRange.size() != rhsRange.size())
778  return failure();
779 
780  // Handle where operands in order are equivalent.
781  auto lhsIt = lhsRange.begin();
782  auto rhsIt = rhsRange.begin();
783  for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
784  if (failed(checkEquivalent(*lhsIt, *rhsIt)))
785  break;
786  }
787  if (lhsIt == lhsRange.end())
788  return success();
789 
790  // Handle another simple case where operands are just a permutation.
791  // Note: This is not sufficient, this handles simple cases relatively
792  // cheaply.
793  auto sortValues = [](ValueRange values) {
794  SmallVector<Value> sortedValues = llvm::to_vector(values);
795  llvm::sort(sortedValues, [](Value a, Value b) {
796  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
797  });
798  return sortedValues;
799  };
800  auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
801  auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
802  return success(lhsSorted == rhsSorted);
803  }
804  void markEquivalent(Value lhsResult, Value rhsResult) {
805  auto insertion = equivalentValues.insert({lhsResult, rhsResult});
806  // Make sure that the value was not already marked equivalent to some other
807  // value.
808  (void)insertion;
809  assert(insertion.first->second == rhsResult &&
810  "inconsistent OperationEquivalence state");
811  }
812 };
813 
814 /*static*/ bool
817  ValueEquivalenceCache cache;
818  return isRegionEquivalentTo(
819  lhs, rhs,
820  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
821  return cache.checkEquivalent(lhsValue, rhsValue);
822  },
823  [&](Value lhsResult, Value rhsResult) {
824  cache.markEquivalent(lhsResult, rhsResult);
825  },
826  flags,
827  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
828  return cache.checkCommutativeEquivalent(lhs, rhs);
829  });
830 }
831 
833  Operation *lhs, Operation *rhs,
834  function_ref<LogicalResult(Value, Value)> checkEquivalent,
835  function_ref<void(Value, Value)> markEquivalent, Flags flags,
836  function_ref<LogicalResult(ValueRange, ValueRange)>
837  checkCommutativeEquivalent) {
838  if (lhs == rhs)
839  return true;
840 
841  // 1. Compare the operation properties.
842  if (!(flags & IgnoreDiscardableAttrs) &&
844  return false;
845 
846  if (lhs->getName() != rhs->getName() ||
847  lhs->getNumRegions() != rhs->getNumRegions() ||
848  lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
849  lhs->getNumOperands() != rhs->getNumOperands() ||
850  lhs->getNumResults() != rhs->getNumResults())
851  return false;
852  if (!(flags & IgnoreProperties) &&
854  rhs->getPropertiesStorage())))
855  return false;
856  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
857  return false;
858 
859  // 2. Compare operands.
860  if (checkCommutativeEquivalent &&
862  auto lhsRange = lhs->getOperands();
863  auto rhsRange = rhs->getOperands();
864  if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
865  return false;
866  } else {
867  // Check pair wise for equivalence.
868  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
869  Value curArg = std::get<0>(operandPair);
870  Value otherArg = std::get<1>(operandPair);
871  if (curArg == otherArg)
872  continue;
873  if (curArg.getType() != otherArg.getType())
874  return false;
875  if (failed(checkEquivalent(curArg, otherArg)))
876  return false;
877  }
878  }
879 
880  // 3. Compare result types and mark results as equivalent.
881  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
882  Value curArg = std::get<0>(resultPair);
883  Value otherArg = std::get<1>(resultPair);
884  if (curArg.getType() != otherArg.getType())
885  return false;
886  if (markEquivalent)
887  markEquivalent(curArg, otherArg);
888  }
889 
890  // 4. Compare regions.
891  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
892  if (!isRegionEquivalentTo(&std::get<0>(regionPair),
893  &std::get<1>(regionPair), checkEquivalent,
894  markEquivalent, flags))
895  return false;
896 
897  return true;
898 }
899 
901  Operation *rhs,
902  Flags flags) {
903  ValueEquivalenceCache cache;
905  lhs, rhs,
906  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
907  return cache.checkEquivalent(lhsValue, rhsValue);
908  },
909  [&](Value lhsResult, Value rhsResult) {
910  cache.markEquivalent(lhsResult, rhsResult);
911  },
912  flags,
913  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
914  return cache.checkCommutativeEquivalent(lhs, rhs);
915  });
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // OperationFingerPrint
920 //===----------------------------------------------------------------------===//
921 
922 template <typename T>
923 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
924  hasher.update(
925  ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
926 }
927 
929  bool includeNested) {
930  llvm::SHA1 hasher;
931 
932  // Helper function that hashes an operation based on its mutable bits:
933  auto addOperationToHash = [&](Operation *op) {
934  // - Operation pointer
935  addDataToHash(hasher, op);
936  // - Parent operation pointer (to take into account the nesting structure)
937  if (op != topOp)
938  addDataToHash(hasher, op->getParentOp());
939  // - Attributes
940  addDataToHash(hasher, op->getRawDictionaryAttrs());
941  // - Properties
942  addDataToHash(hasher, op->hashProperties());
943  // - Blocks in Regions
944  for (Region &region : op->getRegions()) {
945  for (Block &block : region) {
946  addDataToHash(hasher, &block);
947  for (BlockArgument arg : block.getArguments())
948  addDataToHash(hasher, arg);
949  }
950  }
951  // - Location
952  addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
953  // - Operands
954  for (Value operand : op->getOperands())
955  addDataToHash(hasher, operand);
956  // - Successors
957  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
958  addDataToHash(hasher, op->getSuccessor(i));
959  // - Result types
960  for (Type t : op->getResultTypes())
961  addDataToHash(hasher, t);
962  };
963 
964  if (includeNested)
965  topOp->walk(addOperationToHash);
966  else
967  addOperationToHash(topOp);
968 
969  hash = hasher.result();
970 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:309
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumArguments()
Definition: Block.h:128
BlockArgListType getArguments()
Definition: Block.h:87
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Location.h:103
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a contiguous range of mutable operand ranges, e.g.
Definition: ValueRange.h:210
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments={})
Construct a new mutable range from the given operand, operand start index, and range length.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition: ValueRange.h:123
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists, else returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class adds property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:84
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp, bool includeNested=true)
bool compareOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:255
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:388
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
Block * getSuccessor(unsigned index)
Definition: Operation.h:708
unsigned getNumSuccessors()
Definition: Operation.h:706
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition: Operation.h:360
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:509
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
Definition: Operation.cpp:354
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
SuccessorRange getSuccessors()
Definition: Operation.h:703
result_range getResults()
Definition: Operation.h:415
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
Definition: Operation.cpp:369
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements a use iterator for a range of operation results.
Definition: ValueRange.h:350
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition: ValueRange.h:322
ResultRange(OpResult result)
use_iterator use_end() const
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition: ValueRange.h:303
user_iterator user_end()
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:286
user_iterator user_begin()
UseIterator use_iterator
Definition: ValueRange.h:267
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition: ValueRange.h:392
ValueRange(Arg &&arg LLVM_LIFETIME_BOUND)
Definition: ValueRange.h:400
An iterator over the users of an IRObject.
Definition: UseDefLists.h:344
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:233
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, ValueRange rhsRange)
DenseMap< Value, Value > equivalentValues
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addRegions(MutableArrayRef< std::unique_ptr< Region >> regions)
Take ownership of a set of regions that should be attached to the Operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
Adds a successor to the operation state. The successor must not be null.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const