MLIR  21.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include "llvm/Support/SHA1.h"
20 #include <numeric>
21 #include <optional>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // NamedAttrList
27 //===----------------------------------------------------------------------===//
28 
30  assign(attributes.begin(), attributes.end());
31 }
32 
/// Construct the list from an existing dictionary attribute. A null
/// `attributes` is treated as an empty attribute array. After the delegated
/// constructor has populated the list, `attributes` itself is stored as the
/// cached dictionary pointer and the accompanying flag is set to true.
NamedAttrList::NamedAttrList(DictionaryAttr attributes)
    : NamedAttrList(attributes ? attributes.getValue()
                               : ArrayRef<NamedAttribute>()) {
  // Remember the incoming dictionary so it does not have to be rebuilt.
  dictionarySorted.setPointerAndInt(attributes, true);
}
38 
40  assign(inStart, inEnd);
41 }
42 
44 
45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46  std::optional<NamedAttribute> duplicate =
47  DictionaryAttr::findDuplicate(attrs, isSorted());
48  // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49  // state.
50  if (!isSorted())
51  dictionarySorted.setPointerAndInt(nullptr, true);
52  return duplicate;
53 }
54 
55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56  if (!isSorted()) {
57  DictionaryAttr::sortInPlace(attrs);
58  dictionarySorted.setPointerAndInt(nullptr, true);
59  }
60  if (!dictionarySorted.getPointer())
61  dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62  return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
63 }
64 
65 /// Replaces the attributes with new list of attributes.
67  DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
68  dictionarySorted.setPointerAndInt(nullptr, true);
69 }
70 
72  if (isSorted())
73  dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
74  dictionarySorted.setPointer(nullptr);
75  attrs.push_back(newAttribute);
76 }
77 
78 /// Return the specified attribute if present, null otherwise.
79 Attribute NamedAttrList::get(StringRef name) const {
80  auto it = findAttr(*this, name);
81  return it.second ? it.first->getValue() : Attribute();
82 }
83 Attribute NamedAttrList::get(StringAttr name) const {
84  auto it = findAttr(*this, name);
85  return it.second ? it.first->getValue() : Attribute();
86 }
87 
88 /// Return the specified named attribute if present, std::nullopt otherwise.
89 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
90  auto it = findAttr(*this, name);
91  return it.second ? *it.first : std::optional<NamedAttribute>();
92 }
93 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
94  auto it = findAttr(*this, name);
95  return it.second ? *it.first : std::optional<NamedAttribute>();
96 }
97 
98 /// If the an attribute exists with the specified name, change it to the new
99 /// value. Otherwise, add a new attribute with the specified name/value.
100 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
101  assert(value && "attributes may never be null");
102 
103  // Look for an existing attribute with the given name, and set its value
104  // in-place. Return the previous value of the attribute, if there was one.
105  auto it = findAttr(*this, name);
106  if (it.second) {
107  // Update the existing attribute by swapping out the old value for the new
108  // value. Return the old value.
109  Attribute oldValue = it.first->getValue();
110  if (it.first->getValue() != value) {
111  it.first->setValue(value);
112 
113  // If the attributes have changed, the dictionary is invalidated.
114  dictionarySorted.setPointer(nullptr);
115  }
116  return oldValue;
117  }
118  // Perform a string lookup to insert the new attribute into its sorted
119  // position.
120  if (isSorted())
121  it = findAttr(*this, name.strref());
122  attrs.insert(it.first, {name, value});
123  // Invalidate the dictionary. Return null as there was no previous value.
124  dictionarySorted.setPointer(nullptr);
125  return Attribute();
126 }
127 
128 Attribute NamedAttrList::set(StringRef name, Attribute value) {
129  assert(value && "attributes may never be null");
130  return set(mlir::StringAttr::get(value.getContext(), name), value);
131 }
132 
133 Attribute
134 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
135  // Erasing does not affect the sorted property.
136  Attribute attr = it->getValue();
137  attrs.erase(it);
138  dictionarySorted.setPointer(nullptr);
139  return attr;
140 }
141 
142 Attribute NamedAttrList::erase(StringAttr name) {
143  auto it = findAttr(*this, name);
144  return it.second ? eraseImpl(it.first) : Attribute();
145 }
146 
148  auto it = findAttr(*this, name);
149  return it.second ? eraseImpl(it.first) : Attribute();
150 }
151 
154  assign(rhs.begin(), rhs.end());
155  return *this;
156 }
157 
158 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
159 
160 //===----------------------------------------------------------------------===//
161 // OperationState
162 //===----------------------------------------------------------------------===//
163 
/// Construct a state holding only a location and an operation name. The
/// string `name` is paired with the context obtained from `location` to
/// initialize the `name` member.
OperationState::OperationState(Location location, StringRef name)
    : location(location), name(name, location->getContext()) {}
166 
168  : location(location), name(name) {}
169 
171  ValueRange operands, TypeRange types,
172  ArrayRef<NamedAttribute> attributes,
173  BlockRange successors,
174  MutableArrayRef<std::unique_ptr<Region>> regions)
175  : location(location), name(name),
176  operands(operands.begin(), operands.end()),
177  types(types.begin(), types.end()),
178  attributes(attributes.begin(), attributes.end()),
179  successors(successors.begin(), successors.end()) {
180  for (std::unique_ptr<Region> &r : regions)
181  this->regions.push_back(std::move(r));
182 }
/// Construct a state from a string operation name plus operands, result
/// types, attributes, successors and regions. The string is wrapped in an
/// OperationName built with the context of `location`, and everything is
/// forwarded to the OperationName-based constructor.
OperationState::OperationState(Location location, StringRef name,
                               ValueRange operands, TypeRange types,
                               ArrayRef<NamedAttribute> attributes,
                               BlockRange successors,
                               MutableArrayRef<std::unique_ptr<Region>> regions)
    : OperationState(location, OperationName(name, location.getContext()),
                     operands, types, attributes, successors, regions) {}
190 
192  if (properties)
193  propertiesDeleter(properties);
194 }
195 
198  if (LLVM_UNLIKELY(propertiesAttr)) {
199  assert(!properties);
201  }
202  if (properties)
203  propertiesSetter(op->getPropertiesStorage(), properties);
204  return success();
205 }
206 
208  operands.append(newOperands.begin(), newOperands.end());
209 }
210 
212  successors.append(newSuccessors.begin(), newSuccessors.end());
213 }
214 
216  regions.emplace_back(new Region);
217  return regions.back().get();
218 }
219 
220 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
221  regions.push_back(std::move(region));
222 }
223 
225  MutableArrayRef<std::unique_ptr<Region>> regions) {
226  for (std::unique_ptr<Region> &region : regions)
227  addRegion(std::move(region));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // OperandStorage
232 //===----------------------------------------------------------------------===//
233 
235  OpOperand *trailingOperands,
236  ValueRange values)
237  : isStorageDynamic(false), operandStorage(trailingOperands) {
238  numOperands = capacity = values.size();
239  for (unsigned i = 0; i < numOperands; ++i)
240  new (&operandStorage[i]) OpOperand(owner, values[i]);
241 }
242 
244  for (auto &operand : getOperands())
245  operand.~OpOperand();
246 
247  // If the storage is dynamic, deallocate it.
248  if (isStorageDynamic)
249  free(operandStorage);
250 }
251 
252 /// Replace the operands contained in the storage with the ones provided in
253 /// 'values'.
255  MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
256  for (unsigned i = 0, e = values.size(); i != e; ++i)
257  storageOperands[i].set(values[i]);
258 }
259 
260 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
261 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
262 /// than the range pointed to by 'start'+'length'.
263 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
264  unsigned length, ValueRange operands) {
265  // If the new size is the same, we can update inplace.
266  unsigned newSize = operands.size();
267  if (newSize == length) {
268  MutableArrayRef<OpOperand> storageOperands = getOperands();
269  for (unsigned i = 0, e = length; i != e; ++i)
270  storageOperands[start + i].set(operands[i]);
271  return;
272  }
273  // If the new size is greater, remove the extra operands and set the rest
274  // inplace.
275  if (newSize < length) {
276  eraseOperands(start + operands.size(), length - newSize);
277  setOperands(owner, start, newSize, operands);
278  return;
279  }
280  // Otherwise, the new size is greater so we need to grow the storage.
281  auto storageOperands = resize(owner, size() + (newSize - length));
282 
283  // Shift operands to the right to make space for the new operands.
284  unsigned rotateSize = storageOperands.size() - (start + length);
285  auto rbegin = storageOperands.rbegin();
286  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
287 
288  // Update the operands inplace.
289  for (unsigned i = 0, e = operands.size(); i != e; ++i)
290  storageOperands[start + i].set(operands[i]);
291 }
292 
293 /// Erase an operand held by the storage.
294 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
295  MutableArrayRef<OpOperand> operands = getOperands();
296  assert((start + length) <= operands.size());
297  numOperands -= length;
298 
299  // Shift all operands down if the operand to remove is not at the end.
300  if (start != numOperands) {
301  auto *indexIt = std::next(operands.begin(), start);
302  std::rotate(indexIt, std::next(indexIt, length), operands.end());
303  }
304  for (unsigned i = 0; i != length; ++i)
305  operands[numOperands + i].~OpOperand();
306 }
307 
308 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
309  MutableArrayRef<OpOperand> operands = getOperands();
310  assert(eraseIndices.size() == operands.size());
311 
312  // Check that at least one operand is erased.
313  int firstErasedIndice = eraseIndices.find_first();
314  if (firstErasedIndice == -1)
315  return;
316 
317  // Shift all of the removed operands to the end, and destroy them.
318  numOperands = firstErasedIndice;
319  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
320  if (!eraseIndices.test(i))
321  operands[numOperands++] = std::move(operands[i]);
322  for (OpOperand &operand : operands.drop_front(numOperands))
323  operand.~OpOperand();
324 }
325 
326 /// Resize the storage to the given size. Returns the array containing the new
327 /// operands.
328 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
329  unsigned newSize) {
330  // If the number of operands is less than or equal to the current amount, we
331  // can just update in place.
332  MutableArrayRef<OpOperand> origOperands = getOperands();
333  if (newSize <= numOperands) {
334  // If the number of new size is less than the current, remove any extra
335  // operands.
336  for (unsigned i = newSize; i != numOperands; ++i)
337  origOperands[i].~OpOperand();
338  numOperands = newSize;
339  return origOperands.take_front(newSize);
340  }
341 
342  // If the new size is within the original inline capacity, grow inplace.
343  if (newSize <= capacity) {
344  OpOperand *opBegin = origOperands.data();
345  for (unsigned e = newSize; numOperands != e; ++numOperands)
346  new (&opBegin[numOperands]) OpOperand(owner);
347  return MutableArrayRef<OpOperand>(opBegin, newSize);
348  }
349 
350  // Otherwise, we need to allocate a new storage.
351  unsigned newCapacity =
352  std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
353  OpOperand *newOperandStorage =
354  reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
355 
356  // Move the current operands to the new storage.
357  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
358  std::uninitialized_move(origOperands.begin(), origOperands.end(),
359  newOperands.begin());
360 
361  // Destroy the original operands.
362  for (auto &operand : origOperands)
363  operand.~OpOperand();
364 
365  // Initialize any new operands.
366  for (unsigned e = newSize; numOperands != e; ++numOperands)
367  new (&newOperands[numOperands]) OpOperand(owner);
368 
369  // If the current storage is dynamic, free it.
370  if (isStorageDynamic)
371  free(operandStorage);
372 
373  // Update the storage representation to use the new dynamic storage.
374  operandStorage = newOperandStorage;
375  capacity = newCapacity;
376  isStorageDynamic = true;
377  return newOperands;
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // Operation Value-Iterators
382 //===----------------------------------------------------------------------===//
383 
384 //===----------------------------------------------------------------------===//
385 // OperandRange
386 //===----------------------------------------------------------------------===//
387 
389  assert(!empty() && "range must not be empty");
390  return base->getOperandNumber();
391 }
392 
394  return OperandRangeRange(*this, segmentSizes);
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // OperandRangeRange
399 //===----------------------------------------------------------------------===//
400 
402  Attribute operandSegments)
403  : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
404  llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
405 }
406 
408  const OwnerT &owner = getBase();
409  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
410  return OperandRange(owner.first,
411  std::accumulate(sizeData.begin(), sizeData.end(), 0));
412 }
413 
414 OperandRange OperandRangeRange::dereference(const OwnerT &object,
415  ptrdiff_t index) {
416  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
417  uint32_t startIndex =
418  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
419  return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // MutableOperandRange
424 //===----------------------------------------------------------------------===//
425 
426 /// Construct a new mutable range from the given operand, operand start index,
427 /// and range length.
429  Operation *owner, unsigned start, unsigned length,
430  ArrayRef<OperandSegment> operandSegments)
431  : owner(owner), start(start), length(length),
432  operandSegments(operandSegments) {
433  assert((start + length) <= owner->getNumOperands() && "invalid range");
434 }
436  : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
437 
438 /// Construct a new mutable range for the given OpOperand.
440  : MutableOperandRange(opOperand.getOwner(),
441  /*start=*/opOperand.getOperandNumber(),
442  /*length=*/1) {}
443 
444 /// Slice this range into a sub range, with the additional operand segment.
446 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
447  std::optional<OperandSegment> segment) const {
448  assert((subStart + subLen) <= length && "invalid sub-range");
449  MutableOperandRange subSlice(owner, start + subStart, subLen,
450  operandSegments);
451  if (segment)
452  subSlice.operandSegments.push_back(*segment);
453  return subSlice;
454 }
455 
456 /// Append the given values to the range.
458  if (values.empty())
459  return;
460  owner->insertOperands(start + length, values);
461  updateLength(length + values.size());
462 }
463 
464 /// Assign this range to the given values.
466  owner->setOperands(start, length, values);
467  if (length != values.size())
468  updateLength(/*newLength=*/values.size());
469 }
470 
471 /// Assign the range to the given value.
473  if (length == 1) {
474  owner->setOperand(start, value);
475  } else {
476  owner->setOperands(start, length, value);
477  updateLength(/*newLength=*/1);
478  }
479 }
480 
481 /// Erase the operands within the given sub-range.
482 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
483  assert((subStart + subLen) <= length && "invalid sub-range");
484  if (length == 0)
485  return;
486  owner->eraseOperands(start + subStart, subLen);
487  updateLength(length - subLen);
488 }
489 
490 /// Clear this range and erase all of the operands.
492  if (length != 0) {
493  owner->eraseOperands(start, length);
494  updateLength(/*newLength=*/0);
495  }
496 }
497 
498 /// Explicit conversion to an OperandRange.
500  return owner->getOperands().slice(start, length);
501 }
502 
/// Allow implicit conversion to an OperandRange. The conversion simply
/// forwards to `getAsOperandRange()`.
MutableOperandRange::operator OperandRange() const {
  return getAsOperandRange();
}
507 
508 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
509  return owner->getOpOperands().slice(start, length);
510 }
511 
514  return MutableOperandRangeRange(*this, segmentSizes);
515 }
516 
517 /// Update the length of this range to the one provided.
518 void MutableOperandRange::updateLength(unsigned newLength) {
519  int32_t diff = int32_t(newLength) - int32_t(length);
520  length = newLength;
521 
522  // Update any of the provided segment attributes.
523  for (OperandSegment &segment : operandSegments) {
524  auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
525  SmallVector<int32_t, 8> segments(attr.asArrayRef());
526  segments[segment.first] += diff;
527  segment.second.setValue(
528  DenseI32ArrayAttr::get(attr.getContext(), segments));
529  owner->setAttr(segment.second.getName(), segment.second.getValue());
530  }
531 }
532 
534  assert(index < length && "index is out of bounds");
535  return owner->getOpOperand(start + index);
536 }
537 
539  return owner->getOpOperands().slice(start, length).begin();
540 }
541 
543  return owner->getOpOperands().slice(start, length).end();
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // MutableOperandRangeRange
548 //===----------------------------------------------------------------------===//
549 
551  const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
553  OwnerT(operands, operandSegmentAttr), 0,
554  llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
555 }
556 
558  return getBase().first;
559 }
560 
561 MutableOperandRangeRange::operator OperandRangeRange() const {
562  return OperandRangeRange(getBase().first, getBase().second.getValue());
563 }
564 
565 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
566  ptrdiff_t index) {
567  ArrayRef<int32_t> sizeData =
568  llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
569  uint32_t startIndex =
570  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
571  return object.first.slice(
572  startIndex, *(sizeData.begin() + index),
573  MutableOperandRange::OperandSegment(index, object.second));
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // ResultRange
578 //===----------------------------------------------------------------------===//
579 
581  : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
582  1) {}
583 
585  return {use_begin(), use_end()};
586 }
588  return use_iterator(*this);
589 }
591  return use_iterator(*this, /*end=*/true);
592 }
594  return {user_begin(), user_end()};
595 }
597  return user_iterator(use_begin());
598 }
600  return user_iterator(use_end());
601 }
602 
604  : it(end ? results.end() : results.begin()), endIt(results.end()) {
605  // Only initialize current use if there are results/can be uses.
606  if (it != endIt)
607  skipOverResultsWithNoUsers();
608 }
609 
611  // We increment over uses, if we reach the last use then move to next
612  // result.
613  if (use != (*it).use_end())
614  ++use;
615  if (use == (*it).use_end()) {
616  ++it;
617  skipOverResultsWithNoUsers();
618  }
619  return *this;
620 }
621 
622 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
623  while (it != endIt && (*it).use_empty())
624  ++it;
625 
626  // If we are at the last result, then set use to first use of
627  // first result (sentinel value used for end).
628  if (it == endIt)
629  use = {};
630  else
631  use = (*it).use_begin();
632 }
633 
636 }
637 
639  Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
640  replaceUsesWithIf(op->getResults(), shouldReplace);
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // ValueRange
645 //===----------------------------------------------------------------------===//
646 
648  : ValueRange(values.data(), values.size()) {}
650  : ValueRange(values.begin().getBase(), values.size()) {}
652  : ValueRange(values.getBase(), values.size()) {}
653 
654 /// See `llvm::detail::indexed_accessor_range_base` for details.
655 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
656  ptrdiff_t index) {
657  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
658  return {value + index};
659  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
660  return {operand + index};
661  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
662 }
663 /// See `llvm::detail::indexed_accessor_range_base` for details.
664 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
665  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
666  return value[index];
667  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
668  return operand[index].get();
669  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // Operation Equivalency
674 //===----------------------------------------------------------------------===//
675 
677  Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
678  function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
679  // Hash operations based upon their:
680  // - Operation Name
681  // - Attributes
682  // - Result Types
683  DictionaryAttr dictAttrs;
684  if (!(flags & Flags::IgnoreDiscardableAttrs))
685  dictAttrs = op->getRawDictionaryAttrs();
686  llvm::hash_code hash =
687  llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes());
688  if (!(flags & Flags::IgnoreProperties))
689  hash = llvm::hash_combine(hash, op->hashProperties());
690 
691  // - Location if required
692  if (!(flags & Flags::IgnoreLocations))
693  hash = llvm::hash_combine(hash, op->getLoc());
694 
695  // - Operands
697  op->getNumOperands() > 0) {
698  size_t operandHash = hashOperands(op->getOperand(0));
699  for (auto operand : op->getOperands().drop_front())
700  operandHash += hashOperands(operand);
701  hash = llvm::hash_combine(hash, operandHash);
702  } else {
703  for (Value operand : op->getOperands())
704  hash = llvm::hash_combine(hash, hashOperands(operand));
705  }
706 
707  // - Results
708  for (Value result : op->getResults())
709  hash = llvm::hash_combine(hash, hashResults(result));
710  return hash;
711 }
712 
714  Region *lhs, Region *rhs,
715  function_ref<LogicalResult(Value, Value)> checkEquivalent,
716  function_ref<void(Value, Value)> markEquivalent,
718  function_ref<LogicalResult(ValueRange, ValueRange)>
719  checkCommutativeEquivalent) {
720  DenseMap<Block *, Block *> blocksMap;
721  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
722  // Check block arguments.
723  if (lBlock.getNumArguments() != rBlock.getNumArguments())
724  return false;
725 
726  // Map the two blocks.
727  auto insertion = blocksMap.insert({&lBlock, &rBlock});
728  if (insertion.first->getSecond() != &rBlock)
729  return false;
730 
731  for (auto argPair :
732  llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
733  Value curArg = std::get<0>(argPair);
734  Value otherArg = std::get<1>(argPair);
735  if (curArg.getType() != otherArg.getType())
736  return false;
737  if (!(flags & OperationEquivalence::IgnoreLocations) &&
738  curArg.getLoc() != otherArg.getLoc())
739  return false;
740  // Corresponding bbArgs are equivalent.
741  if (markEquivalent)
742  markEquivalent(curArg, otherArg);
743  }
744 
745  auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
746  // Check for op equality (recursively).
747  if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
748  markEquivalent, flags,
749  checkCommutativeEquivalent))
750  return false;
751  // Check successor mapping.
752  for (auto successorsPair :
753  llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
754  Block *curSuccessor = std::get<0>(successorsPair);
755  Block *otherSuccessor = std::get<1>(successorsPair);
756  auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
757  if (insertion.first->getSecond() != otherSuccessor)
758  return false;
759  }
760  return true;
761  };
762  return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
763  };
764  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
765 }
766 
767 // Value equivalence cache to be used with `isRegionEquivalentTo` and
768 // `isEquivalentTo`.
771  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
772  return success(lhsValue == rhsValue ||
773  equivalentValues.lookup(lhsValue) == rhsValue);
774  }
775  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
776  ValueRange rhsRange) {
777  // Handle simple case where sizes mismatch.
778  if (lhsRange.size() != rhsRange.size())
779  return failure();
780 
781  // Handle where operands in order are equivalent.
782  auto lhsIt = lhsRange.begin();
783  auto rhsIt = rhsRange.begin();
784  for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
785  if (failed(checkEquivalent(*lhsIt, *rhsIt)))
786  break;
787  }
788  if (lhsIt == lhsRange.end())
789  return success();
790 
791  // Handle another simple case where operands are just a permutation.
792  // Note: This is not sufficient, this handles simple cases relatively
793  // cheaply.
794  auto sortValues = [](ValueRange values) {
795  SmallVector<Value> sortedValues = llvm::to_vector(values);
796  llvm::sort(sortedValues, [](Value a, Value b) {
797  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
798  });
799  return sortedValues;
800  };
801  auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
802  auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
803  return success(lhsSorted == rhsSorted);
804  }
805  void markEquivalent(Value lhsResult, Value rhsResult) {
806  auto insertion = equivalentValues.insert({lhsResult, rhsResult});
807  // Make sure that the value was not already marked equivalent to some other
808  // value.
809  (void)insertion;
810  assert(insertion.first->second == rhsResult &&
811  "inconsistent OperationEquivalence state");
812  }
813 };
814 
815 /*static*/ bool
818  ValueEquivalenceCache cache;
819  return isRegionEquivalentTo(
820  lhs, rhs,
821  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
822  return cache.checkEquivalent(lhsValue, rhsValue);
823  },
824  [&](Value lhsResult, Value rhsResult) {
825  cache.markEquivalent(lhsResult, rhsResult);
826  },
827  flags,
828  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
829  return cache.checkCommutativeEquivalent(lhs, rhs);
830  });
831 }
832 
834  Operation *lhs, Operation *rhs,
835  function_ref<LogicalResult(Value, Value)> checkEquivalent,
836  function_ref<void(Value, Value)> markEquivalent, Flags flags,
837  function_ref<LogicalResult(ValueRange, ValueRange)>
838  checkCommutativeEquivalent) {
839  if (lhs == rhs)
840  return true;
841 
842  // 1. Compare the operation properties.
843  if (!(flags & IgnoreDiscardableAttrs) &&
845  return false;
846 
847  if (lhs->getName() != rhs->getName() ||
848  lhs->getNumRegions() != rhs->getNumRegions() ||
849  lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
850  lhs->getNumOperands() != rhs->getNumOperands() ||
851  lhs->getNumResults() != rhs->getNumResults())
852  return false;
853  if (!(flags & IgnoreProperties) &&
855  rhs->getPropertiesStorage())))
856  return false;
857  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
858  return false;
859 
860  // 2. Compare operands.
861  if (checkCommutativeEquivalent &&
863  auto lhsRange = lhs->getOperands();
864  auto rhsRange = rhs->getOperands();
865  if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
866  return false;
867  } else {
868  // Check pair wise for equivalence.
869  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
870  Value curArg = std::get<0>(operandPair);
871  Value otherArg = std::get<1>(operandPair);
872  if (curArg == otherArg)
873  continue;
874  if (curArg.getType() != otherArg.getType())
875  return false;
876  if (failed(checkEquivalent(curArg, otherArg)))
877  return false;
878  }
879  }
880 
881  // 3. Compare result types and mark results as equivalent.
882  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
883  Value curArg = std::get<0>(resultPair);
884  Value otherArg = std::get<1>(resultPair);
885  if (curArg.getType() != otherArg.getType())
886  return false;
887  if (markEquivalent)
888  markEquivalent(curArg, otherArg);
889  }
890 
891  // 4. Compare regions.
892  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
893  if (!isRegionEquivalentTo(&std::get<0>(regionPair),
894  &std::get<1>(regionPair), checkEquivalent,
895  markEquivalent, flags))
896  return false;
897 
898  return true;
899 }
900 
902  Operation *rhs,
903  Flags flags) {
904  ValueEquivalenceCache cache;
906  lhs, rhs,
907  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
908  return cache.checkEquivalent(lhsValue, rhsValue);
909  },
910  [&](Value lhsResult, Value rhsResult) {
911  cache.markEquivalent(lhsResult, rhsResult);
912  },
913  flags,
914  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
915  return cache.checkCommutativeEquivalent(lhs, rhs);
916  });
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // OperationFingerPrint
921 //===----------------------------------------------------------------------===//
922 
923 template <typename T>
924 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
925  hasher.update(
926  ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
927 }
928 
930  bool includeNested) {
931  llvm::SHA1 hasher;
932 
933  // Helper function that hashes an operation based on its mutable bits:
934  auto addOperationToHash = [&](Operation *op) {
935  // - Operation pointer
936  addDataToHash(hasher, op);
937  // - Parent operation pointer (to take into account the nesting structure)
938  if (op != topOp)
939  addDataToHash(hasher, op->getParentOp());
940  // - Attributes
941  addDataToHash(hasher, op->getRawDictionaryAttrs());
942  // - Properties
943  addDataToHash(hasher, op->hashProperties());
944  // - Blocks in Regions
945  for (Region &region : op->getRegions()) {
946  for (Block &block : region) {
947  addDataToHash(hasher, &block);
948  for (BlockArgument arg : block.getArguments())
949  addDataToHash(hasher, arg);
950  }
951  }
952  // - Location
953  addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
954  // - Operands
955  for (Value operand : op->getOperands())
956  addDataToHash(hasher, operand);
957  // - Successors
958  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
959  addDataToHash(hasher, op->getSuccessor(i));
960  // - Result types
961  for (Type t : op->getResultTypes())
962  addDataToHash(hasher, t);
963  };
964 
965  if (includeNested)
966  topOp->walk(addOperationToHash);
967  else
968  addOperationToHash(topOp);
969 
970  hash = hasher.result();
971 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:309
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumArguments()
Definition: Block.h:128
BlockArgListType getArguments()
Definition: Block.h:87
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Location.h:103
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a contiguous range of mutable operand ranges, e.g.
Definition: ValueRange.h:210
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments={})
Construct a new mutable range from the given operand, operand start index, and range length.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition: ValueRange.h:123
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists, else returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class adds property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:84
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp, bool includeNested=true)
bool compareOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:388
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
Block * getSuccessor(unsigned index)
Definition: Operation.h:708
unsigned getNumSuccessors()
Definition: Operation.h:706
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition: Operation.h:360
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:509
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
Definition: Operation.cpp:355
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
SuccessorRange getSuccessors()
Definition: Operation.h:703
result_range getResults()
Definition: Operation.h:415
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
Definition: Operation.cpp:370
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements a use iterator for a range of operation results.
Definition: ValueRange.h:350
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition: ValueRange.h:322
ResultRange(OpResult result)
use_iterator use_end() const
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition: ValueRange.h:303
user_iterator user_end()
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:286
user_iterator user_begin()
UseIterator use_iterator
Definition: ValueRange.h:267
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition: ValueRange.h:392
ValueRange(Arg &&arg LLVM_LIFETIME_BOUND)
Definition: ValueRange.h:400
An iterator over the users of an IRObject.
Definition: UseDefLists.h:344
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:233
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, ValueRange rhsRange)
DenseMap< Value, Value > equivalentValues
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addRegions(MutableArrayRef< std::unique_ptr< Region >> regions)
Take ownership of a set of regions that should be attached to the Operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
Adds a successor to the operation state. 'successor' must not be null.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const