MLIR  22.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/Support/SHA1.h"
19 #include <numeric>
20 #include <optional>
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // NamedAttrList
26 //===----------------------------------------------------------------------===//
27 
29  assign(attributes.begin(), attributes.end());
30 }
31 
32 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
33  : NamedAttrList(attributes ? attributes.getValue()
34  : ArrayRef<NamedAttribute>()) {
35  dictionarySorted.setPointerAndInt(attributes, true);
36 }
37 
39  assign(inStart, inEnd);
40 }
41 
43 
44 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
45  std::optional<NamedAttribute> duplicate =
46  DictionaryAttr::findDuplicate(attrs, isSorted());
47  // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
48  // state.
49  if (!isSorted())
50  dictionarySorted.setPointerAndInt(nullptr, true);
51  return duplicate;
52 }
53 
54 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
55  if (!isSorted()) {
56  DictionaryAttr::sortInPlace(attrs);
57  dictionarySorted.setPointerAndInt(nullptr, true);
58  }
59  if (!dictionarySorted.getPointer())
60  dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
61  return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
62 }
63 
64 /// Replaces the attributes with new list of attributes.
66  DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
67  dictionarySorted.setPointerAndInt(nullptr, true);
68 }
69 
71  if (isSorted())
72  dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
73  dictionarySorted.setPointer(nullptr);
74  attrs.push_back(newAttribute);
75 }
76 
77 /// Return the specified attribute if present, null otherwise.
78 Attribute NamedAttrList::get(StringRef name) const {
79  auto it = findAttr(*this, name);
80  return it.second ? it.first->getValue() : Attribute();
81 }
82 Attribute NamedAttrList::get(StringAttr name) const {
83  auto it = findAttr(*this, name);
84  return it.second ? it.first->getValue() : Attribute();
85 }
86 
87 /// Return the specified named attribute if present, std::nullopt otherwise.
88 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
89  auto it = findAttr(*this, name);
90  return it.second ? *it.first : std::optional<NamedAttribute>();
91 }
92 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
93  auto it = findAttr(*this, name);
94  return it.second ? *it.first : std::optional<NamedAttribute>();
95 }
96 
97 /// If the an attribute exists with the specified name, change it to the new
98 /// value. Otherwise, add a new attribute with the specified name/value.
99 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
100  assert(value && "attributes may never be null");
101 
102  // Look for an existing attribute with the given name, and set its value
103  // in-place. Return the previous value of the attribute, if there was one.
104  auto it = findAttr(*this, name);
105  if (it.second) {
106  // Update the existing attribute by swapping out the old value for the new
107  // value. Return the old value.
108  Attribute oldValue = it.first->getValue();
109  if (it.first->getValue() != value) {
110  it.first->setValue(value);
111 
112  // If the attributes have changed, the dictionary is invalidated.
113  dictionarySorted.setPointer(nullptr);
114  }
115  return oldValue;
116  }
117  // Perform a string lookup to insert the new attribute into its sorted
118  // position.
119  if (isSorted())
120  it = findAttr(*this, name.strref());
121  attrs.insert(it.first, {name, value});
122  // Invalidate the dictionary. Return null as there was no previous value.
123  dictionarySorted.setPointer(nullptr);
124  return Attribute();
125 }
126 
127 Attribute NamedAttrList::set(StringRef name, Attribute value) {
128  assert(value && "attributes may never be null");
129  return set(mlir::StringAttr::get(value.getContext(), name), value);
130 }
131 
132 Attribute
133 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
134  // Erasing does not affect the sorted property.
135  Attribute attr = it->getValue();
136  attrs.erase(it);
137  dictionarySorted.setPointer(nullptr);
138  return attr;
139 }
140 
141 Attribute NamedAttrList::erase(StringAttr name) {
142  auto it = findAttr(*this, name);
143  return it.second ? eraseImpl(it.first) : Attribute();
144 }
145 
147  auto it = findAttr(*this, name);
148  return it.second ? eraseImpl(it.first) : Attribute();
149 }
150 
153  assign(rhs.begin(), rhs.end());
154  return *this;
155 }
156 
157 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
158 
159 //===----------------------------------------------------------------------===//
160 // OperationState
161 //===----------------------------------------------------------------------===//
162 
163 OperationState::OperationState(Location location, StringRef name)
164  : location(location), name(name, location->getContext()) {}
165 
167  : location(location), name(name) {}
168 
170  ValueRange operands, TypeRange types,
171  ArrayRef<NamedAttribute> attributes,
172  BlockRange successors,
173  MutableArrayRef<std::unique_ptr<Region>> regions)
174  : location(location), name(name),
175  operands(operands.begin(), operands.end()),
176  types(types.begin(), types.end()),
177  attributes(attributes.begin(), attributes.end()),
178  successors(successors.begin(), successors.end()) {
179  for (std::unique_ptr<Region> &r : regions)
180  this->regions.push_back(std::move(r));
181 }
182 OperationState::OperationState(Location location, StringRef name,
183  ValueRange operands, TypeRange types,
184  ArrayRef<NamedAttribute> attributes,
185  BlockRange successors,
186  MutableArrayRef<std::unique_ptr<Region>> regions)
187  : OperationState(location, OperationName(name, location.getContext()),
188  operands, types, attributes, successors, regions) {}
189 
191  if (properties)
192  propertiesDeleter(properties);
193 }
194 
197  if (LLVM_UNLIKELY(propertiesAttr)) {
198  assert(!properties);
200  }
201  if (properties)
202  propertiesSetter(op->getPropertiesStorage(), properties);
203  return success();
204 }
205 
207  operands.append(newOperands.begin(), newOperands.end());
208 }
209 
211  successors.append(newSuccessors.begin(), newSuccessors.end());
212 }
213 
215  regions.emplace_back(new Region);
216  return regions.back().get();
217 }
218 
219 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
220  regions.push_back(std::move(region));
221 }
222 
224  MutableArrayRef<std::unique_ptr<Region>> regions) {
225  for (std::unique_ptr<Region> &region : regions)
226  addRegion(std::move(region));
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // OperandStorage
231 //===----------------------------------------------------------------------===//
232 
234  OpOperand *trailingOperands,
235  ValueRange values)
236  : isStorageDynamic(false), operandStorage(trailingOperands) {
237  numOperands = capacity = values.size();
238  for (unsigned i = 0; i < numOperands; ++i)
239  new (&operandStorage[i]) OpOperand(owner, values[i]);
240 }
241 
243  for (auto &operand : getOperands())
244  operand.~OpOperand();
245 
246  // If the storage is dynamic, deallocate it.
247  if (isStorageDynamic)
248  free(operandStorage);
249 }
250 
251 /// Replace the operands contained in the storage with the ones provided in
252 /// 'values'.
254  MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
255  for (unsigned i = 0, e = values.size(); i != e; ++i)
256  storageOperands[i].set(values[i]);
257 }
258 
259 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
260 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
261 /// than the range pointed to by 'start'+'length'.
262 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
263  unsigned length, ValueRange operands) {
264  // If the new size is the same, we can update inplace.
265  unsigned newSize = operands.size();
266  if (newSize == length) {
267  MutableArrayRef<OpOperand> storageOperands = getOperands();
268  for (unsigned i = 0, e = length; i != e; ++i)
269  storageOperands[start + i].set(operands[i]);
270  return;
271  }
272  // If the new size is greater, remove the extra operands and set the rest
273  // inplace.
274  if (newSize < length) {
275  eraseOperands(start + operands.size(), length - newSize);
276  setOperands(owner, start, newSize, operands);
277  return;
278  }
279  // Otherwise, the new size is greater so we need to grow the storage.
280  auto storageOperands = resize(owner, size() + (newSize - length));
281 
282  // Shift operands to the right to make space for the new operands.
283  unsigned rotateSize = storageOperands.size() - (start + length);
284  auto rbegin = storageOperands.rbegin();
285  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
286 
287  // Update the operands inplace.
288  for (unsigned i = 0, e = operands.size(); i != e; ++i)
289  storageOperands[start + i].set(operands[i]);
290 }
291 
/// Erase the `length` operands held by the storage starting at index `start`.
/// The operands after the erased range keep their relative order and move
/// down to fill the gap; the vacated trailing slots are destroyed in place.
void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
  MutableArrayRef<OpOperand> operands = getOperands();
  assert((start + length) <= operands.size());
  numOperands -= length;

  // Shift all operands down if the operand to remove is not at the end.
  // `numOperands` is already the post-erase count here, so equality means the
  // erased range was the tail and no rotation is needed.
  if (start != numOperands) {
    auto *indexIt = std::next(operands.begin(), start);
    std::rotate(indexIt, std::next(indexIt, length), operands.end());
  }
  // The erased operands now occupy [numOperands, numOperands + length);
  // `operands` still spans the pre-erase size, so these indices are in range.
  for (unsigned i = 0; i != length; ++i)
    operands[numOperands + i].~OpOperand();
}
306 
/// Erase every operand whose index is set in `eraseIndices`, which must hold
/// exactly one bit per operand. Surviving operands keep their relative order
/// and are compacted towards the front; the vacated trailing slots are then
/// destroyed in place.
void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
  MutableArrayRef<OpOperand> operands = getOperands();
  assert(eraseIndices.size() == operands.size());

  // Check that at least one operand is erased.
  int firstErasedIndice = eraseIndices.find_first();
  if (firstErasedIndice == -1)
    return;

  // Shift all of the removed operands to the end, and destroy them.
  // Everything before the first erased index is already in place, so the
  // compaction write cursor (`numOperands`) starts there.
  numOperands = firstErasedIndice;
  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
    if (!eraseIndices.test(i))
      operands[numOperands++] = std::move(operands[i]);
  // `operands` still spans the original size, so this visits exactly the
  // slots past the surviving operands.
  for (OpOperand &operand : operands.drop_front(numOperands))
    operand.~OpOperand();
}
324 
/// Resize the storage to the given size. Returns the array containing the new
/// operands.
///
/// - Shrinking (or same size): trailing operands are destroyed in place.
/// - Growing within `capacity`: new operands are placement-constructed with
///   `owner` in the existing buffer.
/// - Growing past `capacity`: a new buffer of
///   max(NextPowerOf2(capacity + 2), newSize) operands is malloc'd, existing
///   operands are moved into it, new ones are constructed with `owner`, the
///   old buffer is freed if it was itself dynamically allocated, and the
///   storage switches to the new buffer. Pointers/references to operands
///   obtained before such a reallocation must not be used afterwards.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
                                                          unsigned newSize) {
  // If the number of operands is less than or equal to the current amount, we
  // can just update in place.
  MutableArrayRef<OpOperand> origOperands = getOperands();
  if (newSize <= numOperands) {
    // If the number of new size is less than the current, remove any extra
    // operands.
    for (unsigned i = newSize; i != numOperands; ++i)
      origOperands[i].~OpOperand();
    numOperands = newSize;
    return origOperands.take_front(newSize);
  }

  // If the new size is within the original inline capacity, grow inplace.
  if (newSize <= capacity) {
    OpOperand *opBegin = origOperands.data();
    for (unsigned e = newSize; numOperands != e; ++numOperands)
      new (&opBegin[numOperands]) OpOperand(owner);
    return MutableArrayRef<OpOperand>(opBegin, newSize);
  }

  // Otherwise, we need to allocate a new storage.
  unsigned newCapacity =
      std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
  // NOTE(review): the malloc result is not checked for null before use.
  OpOperand *newOperandStorage =
      reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));

  // Move the current operands to the new storage.
  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
  std::uninitialized_move(origOperands.begin(), origOperands.end(),
                          newOperands.begin());

  // Destroy the original operands.
  for (auto &operand : origOperands)
    operand.~OpOperand();

  // Initialize any new operands.
  for (unsigned e = newSize; numOperands != e; ++numOperands)
    new (&newOperands[numOperands]) OpOperand(owner);

  // If the current storage is dynamic, free it.
  // (Non-dynamic storage is the trailing buffer passed to the constructor and
  // is not owned by this object.)
  if (isStorageDynamic)
    free(operandStorage);

  // Update the storage representation to use the new dynamic storage.
  operandStorage = newOperandStorage;
  capacity = newCapacity;
  isStorageDynamic = true;
  return newOperands;
}
378 
379 //===----------------------------------------------------------------------===//
380 // Operation Value-Iterators
381 //===----------------------------------------------------------------------===//
382 
383 //===----------------------------------------------------------------------===//
384 // OperandRange
385 //===----------------------------------------------------------------------===//
386 
388  assert(!empty() && "range must not be empty");
389  return base->getOperandNumber();
390 }
391 
393  return OperandRangeRange(*this, segmentSizes);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // OperandRangeRange
398 //===----------------------------------------------------------------------===//
399 
401  Attribute operandSegments)
402  : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
403  llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
404 }
405 
407  const OwnerT &owner = getBase();
408  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
409  return OperandRange(owner.first, llvm::sum_of(sizeData));
410 }
411 
412 OperandRange OperandRangeRange::dereference(const OwnerT &object,
413  ptrdiff_t index) {
414  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
415  uint32_t startIndex = llvm::sum_of(sizeData.take_front(index));
416  return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // MutableOperandRange
421 //===----------------------------------------------------------------------===//
422 
423 /// Construct a new mutable range from the given operand, operand start index,
424 /// and range length.
426  Operation *owner, unsigned start, unsigned length,
427  ArrayRef<OperandSegment> operandSegments)
428  : owner(owner), start(start), length(length),
429  operandSegments(operandSegments) {
430  assert((start + length) <= owner->getNumOperands() && "invalid range");
431 }
433  : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
434 
435 /// Construct a new mutable range for the given OpOperand.
437  : MutableOperandRange(opOperand.getOwner(),
438  /*start=*/opOperand.getOperandNumber(),
439  /*length=*/1) {}
440 
441 /// Slice this range into a sub range, with the additional operand segment.
443 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
444  std::optional<OperandSegment> segment) const {
445  assert((subStart + subLen) <= length && "invalid sub-range");
446  MutableOperandRange subSlice(owner, start + subStart, subLen,
447  operandSegments);
448  if (segment)
449  subSlice.operandSegments.push_back(*segment);
450  return subSlice;
451 }
452 
453 /// Append the given values to the range.
455  if (values.empty())
456  return;
457  owner->insertOperands(start + length, values);
458  updateLength(length + values.size());
459 }
460 
461 /// Assign this range to the given values.
463  owner->setOperands(start, length, values);
464  if (length != values.size())
465  updateLength(/*newLength=*/values.size());
466 }
467 
468 /// Assign the range to the given value.
470  if (length == 1) {
471  owner->setOperand(start, value);
472  } else {
473  owner->setOperands(start, length, value);
474  updateLength(/*newLength=*/1);
475  }
476 }
477 
478 /// Erase the operands within the given sub-range.
479 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
480  assert((subStart + subLen) <= length && "invalid sub-range");
481  if (length == 0)
482  return;
483  owner->eraseOperands(start + subStart, subLen);
484  updateLength(length - subLen);
485 }
486 
487 /// Clear this range and erase all of the operands.
489  if (length != 0) {
490  owner->eraseOperands(start, length);
491  updateLength(/*newLength=*/0);
492  }
493 }
494 
495 /// Explicit conversion to an OperandRange.
497  return owner->getOperands().slice(start, length);
498 }
499 
500 /// Allow implicit conversion to an OperandRange.
501 MutableOperandRange::operator OperandRange() const {
502  return getAsOperandRange();
503 }
504 
505 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
506  return owner->getOpOperands().slice(start, length);
507 }
508 
511  return MutableOperandRangeRange(*this, segmentSizes);
512 }
513 
514 /// Update the length of this range to the one provided.
515 void MutableOperandRange::updateLength(unsigned newLength) {
516  int32_t diff = int32_t(newLength) - int32_t(length);
517  length = newLength;
518 
519  // Update any of the provided segment attributes.
520  for (OperandSegment &segment : operandSegments) {
521  auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
522  SmallVector<int32_t, 8> segments(attr.asArrayRef());
523  segments[segment.first] += diff;
524  segment.second.setValue(
525  DenseI32ArrayAttr::get(attr.getContext(), segments));
526  owner->setAttr(segment.second.getName(), segment.second.getValue());
527  }
528 }
529 
531  assert(index < length && "index is out of bounds");
532  return owner->getOpOperand(start + index);
533 }
534 
536  return owner->getOpOperands().slice(start, length).begin();
537 }
538 
540  return owner->getOpOperands().slice(start, length).end();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // MutableOperandRangeRange
545 //===----------------------------------------------------------------------===//
546 
548  const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
550  OwnerT(operands, operandSegmentAttr), 0,
551  llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
552 }
553 
555  return getBase().first;
556 }
557 
558 MutableOperandRangeRange::operator OperandRangeRange() const {
559  return OperandRangeRange(getBase().first, getBase().second.getValue());
560 }
561 
562 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
563  ptrdiff_t index) {
564  ArrayRef<int32_t> sizeData =
565  llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
566  uint32_t startIndex = llvm::sum_of(sizeData.take_front(index));
567  return object.first.slice(
568  startIndex, *(sizeData.begin() + index),
569  MutableOperandRange::OperandSegment(index, object.second));
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // ResultRange
574 //===----------------------------------------------------------------------===//
575 
577  : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
578  1) {}
579 
581  return {use_begin(), use_end()};
582 }
584  return use_iterator(*this);
585 }
587  return use_iterator(*this, /*end=*/true);
588 }
590  return {user_begin(), user_end()};
591 }
593  return user_iterator(use_begin());
594 }
596  return user_iterator(use_end());
597 }
598 
600  : it(end ? results.end() : results.begin()), endIt(results.end()) {
601  // Only initialize current use if there are results/can be uses.
602  if (it != endIt)
603  skipOverResultsWithNoUsers();
604 }
605 
607  // We increment over uses, if we reach the last use then move to next
608  // result.
609  if (use != (*it).use_end())
610  ++use;
611  if (use == (*it).use_end()) {
612  ++it;
613  skipOverResultsWithNoUsers();
614  }
615  return *this;
616 }
617 
618 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
619  while (it != endIt && (*it).use_empty())
620  ++it;
621 
622  // If we are at the last result, then set use to first use of
623  // first result (sentinel value used for end).
624  if (it == endIt)
625  use = {};
626  else
627  use = (*it).use_begin();
628 }
629 
632 }
633 
635  Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
636  replaceUsesWithIf(op->getResults(), shouldReplace);
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // ValueRange
641 //===----------------------------------------------------------------------===//
642 
644  : ValueRange(values.data(), values.size()) {}
646  : ValueRange(values.begin().getBase(), values.size()) {}
648  : ValueRange(values.getBase(), values.size()) {}
649 
650 /// See `llvm::detail::indexed_accessor_range_base` for details.
651 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
652  ptrdiff_t index) {
653  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
654  return {value + index};
655  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
656  return {operand + index};
657  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
658 }
659 /// See `llvm::detail::indexed_accessor_range_base` for details.
660 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
661  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
662  return value[index];
663  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
664  return operand[index].get();
665  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // Operation Equivalency
670 //===----------------------------------------------------------------------===//
671 
673  Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
674  function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
675  // Hash operations based upon their:
676  // - Operation Name
677  // - Attributes
678  // - Result Types
679  DictionaryAttr dictAttrs;
680  if (!(flags & Flags::IgnoreDiscardableAttrs))
681  dictAttrs = op->getRawDictionaryAttrs();
682  llvm::hash_code hash =
683  llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes());
684  if (!(flags & Flags::IgnoreProperties))
685  hash = llvm::hash_combine(hash, op->hashProperties());
686 
687  // - Location if required
688  if (!(flags & Flags::IgnoreLocations))
689  hash = llvm::hash_combine(hash, op->getLoc());
690 
691  // - Operands
693  op->getNumOperands() > 0) {
694  size_t operandHash = hashOperands(op->getOperand(0));
695  for (auto operand : op->getOperands().drop_front())
696  operandHash += hashOperands(operand);
697  hash = llvm::hash_combine(hash, operandHash);
698  } else {
699  for (Value operand : op->getOperands())
700  hash = llvm::hash_combine(hash, hashOperands(operand));
701  }
702 
703  // - Results
704  for (Value result : op->getResults())
705  hash = llvm::hash_combine(hash, hashResults(result));
706  return hash;
707 }
708 
710  Region *lhs, Region *rhs,
711  function_ref<LogicalResult(Value, Value)> checkEquivalent,
712  function_ref<void(Value, Value)> markEquivalent,
714  function_ref<LogicalResult(ValueRange, ValueRange)>
715  checkCommutativeEquivalent) {
716  DenseMap<Block *, Block *> blocksMap;
717  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
718  // Check block arguments.
719  if (lBlock.getNumArguments() != rBlock.getNumArguments())
720  return false;
721 
722  // Map the two blocks.
723  auto insertion = blocksMap.insert({&lBlock, &rBlock});
724  if (insertion.first->getSecond() != &rBlock)
725  return false;
726 
727  for (auto argPair :
728  llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
729  Value curArg = std::get<0>(argPair);
730  Value otherArg = std::get<1>(argPair);
731  if (curArg.getType() != otherArg.getType())
732  return false;
733  if (!(flags & OperationEquivalence::IgnoreLocations) &&
734  curArg.getLoc() != otherArg.getLoc())
735  return false;
736  // Corresponding bbArgs are equivalent.
737  if (markEquivalent)
738  markEquivalent(curArg, otherArg);
739  }
740 
741  auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
742  // Check for op equality (recursively).
743  if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
744  markEquivalent, flags,
745  checkCommutativeEquivalent))
746  return false;
747  // Check successor mapping.
748  for (auto successorsPair :
749  llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
750  Block *curSuccessor = std::get<0>(successorsPair);
751  Block *otherSuccessor = std::get<1>(successorsPair);
752  auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
753  if (insertion.first->getSecond() != otherSuccessor)
754  return false;
755  }
756  return true;
757  };
758  return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
759  };
760  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
761 }
762 
763 // Value equivalence cache to be used with `isRegionEquivalentTo` and
764 // `isEquivalentTo`.
767  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
768  return success(lhsValue == rhsValue ||
769  equivalentValues.lookup(lhsValue) == rhsValue);
770  }
771  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
772  ValueRange rhsRange) {
773  // Handle simple case where sizes mismatch.
774  if (lhsRange.size() != rhsRange.size())
775  return failure();
776 
777  // Handle where operands in order are equivalent.
778  auto lhsIt = lhsRange.begin();
779  auto rhsIt = rhsRange.begin();
780  for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
781  if (failed(checkEquivalent(*lhsIt, *rhsIt)))
782  break;
783  }
784  if (lhsIt == lhsRange.end())
785  return success();
786 
787  // Handle another simple case where operands are just a permutation.
788  // Note: This is not sufficient, this handles simple cases relatively
789  // cheaply.
790  auto sortValues = [](ValueRange values) {
791  SmallVector<Value> sortedValues = llvm::to_vector(values);
792  llvm::sort(sortedValues, [](Value a, Value b) {
793  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
794  });
795  return sortedValues;
796  };
797  auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
798  auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
799  return success(lhsSorted == rhsSorted);
800  }
801  void markEquivalent(Value lhsResult, Value rhsResult) {
802  auto insertion = equivalentValues.insert({lhsResult, rhsResult});
803  // Make sure that the value was not already marked equivalent to some other
804  // value.
805  (void)insertion;
806  assert(insertion.first->second == rhsResult &&
807  "inconsistent OperationEquivalence state");
808  }
809 };
810 
811 /*static*/ bool
814  ValueEquivalenceCache cache;
815  return isRegionEquivalentTo(
816  lhs, rhs,
817  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
818  return cache.checkEquivalent(lhsValue, rhsValue);
819  },
820  [&](Value lhsResult, Value rhsResult) {
821  cache.markEquivalent(lhsResult, rhsResult);
822  },
823  flags,
824  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
825  return cache.checkCommutativeEquivalent(lhs, rhs);
826  });
827 }
828 
830  Operation *lhs, Operation *rhs,
831  function_ref<LogicalResult(Value, Value)> checkEquivalent,
832  function_ref<void(Value, Value)> markEquivalent, Flags flags,
833  function_ref<LogicalResult(ValueRange, ValueRange)>
834  checkCommutativeEquivalent) {
835  if (lhs == rhs)
836  return true;
837 
838  // 1. Compare the operation properties.
839  if (!(flags & IgnoreDiscardableAttrs) &&
841  return false;
842 
843  if (lhs->getName() != rhs->getName() ||
844  lhs->getNumRegions() != rhs->getNumRegions() ||
845  lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
846  lhs->getNumOperands() != rhs->getNumOperands() ||
847  lhs->getNumResults() != rhs->getNumResults())
848  return false;
849  if (!(flags & IgnoreProperties) &&
851  rhs->getPropertiesStorage())))
852  return false;
853  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
854  return false;
855 
856  // 2. Compare operands.
857  if (checkCommutativeEquivalent &&
859  auto lhsRange = lhs->getOperands();
860  auto rhsRange = rhs->getOperands();
861  if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
862  return false;
863  } else {
864  // Check pair wise for equivalence.
865  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
866  Value curArg = std::get<0>(operandPair);
867  Value otherArg = std::get<1>(operandPair);
868  if (curArg == otherArg)
869  continue;
870  if (curArg.getType() != otherArg.getType())
871  return false;
872  if (failed(checkEquivalent(curArg, otherArg)))
873  return false;
874  }
875  }
876 
877  // 3. Compare result types and mark results as equivalent.
878  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
879  Value curArg = std::get<0>(resultPair);
880  Value otherArg = std::get<1>(resultPair);
881  if (curArg.getType() != otherArg.getType())
882  return false;
883  if (markEquivalent)
884  markEquivalent(curArg, otherArg);
885  }
886 
887  // 4. Compare regions.
888  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
889  if (!isRegionEquivalentTo(&std::get<0>(regionPair),
890  &std::get<1>(regionPair), checkEquivalent,
891  markEquivalent, flags))
892  return false;
893 
894  return true;
895 }
896 
898  Operation *rhs,
899  Flags flags) {
900  ValueEquivalenceCache cache;
902  lhs, rhs,
903  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
904  return cache.checkEquivalent(lhsValue, rhsValue);
905  },
906  [&](Value lhsResult, Value rhsResult) {
907  cache.markEquivalent(lhsResult, rhsResult);
908  },
909  flags,
910  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
911  return cache.checkCommutativeEquivalent(lhs, rhs);
912  });
913 }
914 
915 //===----------------------------------------------------------------------===//
916 // OperationFingerPrint
917 //===----------------------------------------------------------------------===//
918 
919 template <typename T>
920 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
921  hasher.update(
922  ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
923 }
924 
926  bool includeNested) {
927  llvm::SHA1 hasher;
928 
929  // Helper function that hashes an operation based on its mutable bits:
930  auto addOperationToHash = [&](Operation *op) {
931  // - Operation pointer
932  addDataToHash(hasher, op);
933  // - Parent operation pointer (to take into account the nesting structure)
934  if (op != topOp)
935  addDataToHash(hasher, op->getParentOp());
936  // - Attributes
937  addDataToHash(hasher, op->getRawDictionaryAttrs());
938  // - Properties
939  addDataToHash(hasher, op->hashProperties());
940  // - Blocks in Regions
941  for (Region &region : op->getRegions()) {
942  for (Block &block : region) {
943  addDataToHash(hasher, &block);
944  for (BlockArgument arg : block.getArguments())
945  addDataToHash(hasher, arg);
946  }
947  }
948  // - Location
949  addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
950  // - Operands
951  for (Value operand : op->getOperands())
952  addDataToHash(hasher, operand);
953  // - Successors
954  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
955  addDataToHash(hasher, op->getSuccessor(i));
956  // - Result types
957  for (Type t : op->getResultTypes())
958  addDataToHash(hasher, t);
959  };
960 
961  if (includeNested)
962  topOp->walk(addOperationToHash);
963  else
964  addOperationToHash(topOp);
965 
966  hash = hasher.result();
967 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:309
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumArguments()
Definition: Block.h:128
BlockArgListType getArguments()
Definition: Block.h:87
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:316
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Location.h:103
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a contiguous range of mutable operand ranges, e.g.
Definition: ValueRange.h:210
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments={})
Construct a new mutable range from the given operand, operand start index, and range length.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition: ValueRange.h:123
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists, else returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:457
This class adds property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:84
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp, bool includeNested=true)
bool compareOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:388
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
Block * getSuccessor(unsigned index)
Definition: Operation.h:708
unsigned getNumSuccessors()
Definition: Operation.h:706
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition: Operation.h:360
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:509
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
Definition: Operation.cpp:355
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
SuccessorRange getSuccessors()
Definition: Operation.h:703
result_range getResults()
Definition: Operation.h:415
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
Definition: Operation.cpp:370
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements a use iterator for a range of operation results.
Definition: ValueRange.h:350
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition: ValueRange.h:322
ResultRange(OpResult result)
use_iterator use_end() const
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition: ValueRange.h:303
user_iterator user_end()
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:286
user_iterator user_begin()
UseIterator use_iterator
Definition: ValueRange.h:267
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition: ValueRange.h:392
ValueRange(Arg &&arg LLVM_LIFETIME_BOUND)
Definition: ValueRange.h:400
An iterator over the users of an IRObject.
Definition: UseDefLists.h:344
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:233
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, ValueRange rhsRange)
DenseMap< Value, Value > equivalentValues
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addRegions(MutableArrayRef< std::unique_ptr< Region >> regions)
Take ownership of a set of regions that should be attached to the Operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
Adds a successor to the operation state. The successor must not be null.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const