MLIR  18.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include "llvm/Support/SHA1.h"
20 #include <numeric>
21 #include <optional>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // NamedAttrList
27 //===----------------------------------------------------------------------===//
28 
30  assign(attributes.begin(), attributes.end());
31 }
32 
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34  : NamedAttrList(attributes ? attributes.getValue()
35  : ArrayRef<NamedAttribute>()) {
36  dictionarySorted.setPointerAndInt(attributes, true);
37 }
38 
40  assign(inStart, inEnd);
41 }
42 
44 
45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46  std::optional<NamedAttribute> duplicate =
47  DictionaryAttr::findDuplicate(attrs, isSorted());
48  // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49  // state.
50  if (!isSorted())
51  dictionarySorted.setPointerAndInt(nullptr, true);
52  return duplicate;
53 }
54 
55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56  if (!isSorted()) {
57  DictionaryAttr::sortInPlace(attrs);
58  dictionarySorted.setPointerAndInt(nullptr, true);
59  }
60  if (!dictionarySorted.getPointer())
61  dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62  return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
63 }
64 
65 /// Add an attribute with the specified name.
66 void NamedAttrList::append(StringRef name, Attribute attr) {
67  append(StringAttr::get(attr.getContext(), name), attr);
68 }
69 
70 /// Replaces the attributes with new list of attributes.
72  DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
73  dictionarySorted.setPointerAndInt(nullptr, true);
74 }
75 
77  if (isSorted())
78  dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
79  dictionarySorted.setPointer(nullptr);
80  attrs.push_back(newAttribute);
81 }
82 
83 /// Return the specified attribute if present, null otherwise.
84 Attribute NamedAttrList::get(StringRef name) const {
85  auto it = findAttr(*this, name);
86  return it.second ? it.first->getValue() : Attribute();
87 }
88 Attribute NamedAttrList::get(StringAttr name) const {
89  auto it = findAttr(*this, name);
90  return it.second ? it.first->getValue() : Attribute();
91 }
92 
93 /// Return the specified named attribute if present, std::nullopt otherwise.
94 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
95  auto it = findAttr(*this, name);
96  return it.second ? *it.first : std::optional<NamedAttribute>();
97 }
98 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
99  auto it = findAttr(*this, name);
100  return it.second ? *it.first : std::optional<NamedAttribute>();
101 }
102 
103 /// If the an attribute exists with the specified name, change it to the new
104 /// value. Otherwise, add a new attribute with the specified name/value.
105 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
106  assert(value && "attributes may never be null");
107 
108  // Look for an existing attribute with the given name, and set its value
109  // in-place. Return the previous value of the attribute, if there was one.
110  auto it = findAttr(*this, name);
111  if (it.second) {
112  // Update the existing attribute by swapping out the old value for the new
113  // value. Return the old value.
114  Attribute oldValue = it.first->getValue();
115  if (it.first->getValue() != value) {
116  it.first->setValue(value);
117 
118  // If the attributes have changed, the dictionary is invalidated.
119  dictionarySorted.setPointer(nullptr);
120  }
121  return oldValue;
122  }
123  // Perform a string lookup to insert the new attribute into its sorted
124  // position.
125  if (isSorted())
126  it = findAttr(*this, name.strref());
127  attrs.insert(it.first, {name, value});
128  // Invalidate the dictionary. Return null as there was no previous value.
129  dictionarySorted.setPointer(nullptr);
130  return Attribute();
131 }
132 
133 Attribute NamedAttrList::set(StringRef name, Attribute value) {
134  assert(value && "attributes may never be null");
135  return set(mlir::StringAttr::get(value.getContext(), name), value);
136 }
137 
138 Attribute
139 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
140  // Erasing does not affect the sorted property.
141  Attribute attr = it->getValue();
142  attrs.erase(it);
143  dictionarySorted.setPointer(nullptr);
144  return attr;
145 }
146 
147 Attribute NamedAttrList::erase(StringAttr name) {
148  auto it = findAttr(*this, name);
149  return it.second ? eraseImpl(it.first) : Attribute();
150 }
151 
153  auto it = findAttr(*this, name);
154  return it.second ? eraseImpl(it.first) : Attribute();
155 }
156 
159  assign(rhs.begin(), rhs.end());
160  return *this;
161 }
162 
163 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
164 
165 //===----------------------------------------------------------------------===//
166 // OperationState
167 //===----------------------------------------------------------------------===//
168 
169 OperationState::OperationState(Location location, StringRef name)
170  : location(location), name(name, location->getContext()) {}
171 
173  : location(location), name(name) {}
174 
176  ValueRange operands, TypeRange types,
177  ArrayRef<NamedAttribute> attributes,
178  BlockRange successors,
179  MutableArrayRef<std::unique_ptr<Region>> regions)
180  : location(location), name(name),
181  operands(operands.begin(), operands.end()),
182  types(types.begin(), types.end()),
183  attributes(attributes.begin(), attributes.end()),
184  successors(successors.begin(), successors.end()) {
185  for (std::unique_ptr<Region> &r : regions)
186  this->regions.push_back(std::move(r));
187 }
188 OperationState::OperationState(Location location, StringRef name,
189  ValueRange operands, TypeRange types,
190  ArrayRef<NamedAttribute> attributes,
191  BlockRange successors,
192  MutableArrayRef<std::unique_ptr<Region>> regions)
193  : OperationState(location, OperationName(name, location.getContext()),
194  operands, types, attributes, successors, regions) {}
195 
197  if (properties)
198  propertiesDeleter(properties);
199 }
200 
203  if (LLVM_UNLIKELY(propertiesAttr)) {
204  assert(!properties);
206  }
207  if (properties)
208  propertiesSetter(op->getPropertiesStorage(), properties);
209  return success();
210 }
211 
213  operands.append(newOperands.begin(), newOperands.end());
214 }
215 
217  successors.append(newSuccessors.begin(), newSuccessors.end());
218 }
219 
221  regions.emplace_back(new Region);
222  return regions.back().get();
223 }
224 
225 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
226  regions.push_back(std::move(region));
227 }
228 
230  MutableArrayRef<std::unique_ptr<Region>> regions) {
231  for (std::unique_ptr<Region> &region : regions)
232  addRegion(std::move(region));
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // OperandStorage
237 //===----------------------------------------------------------------------===//
238 
240  OpOperand *trailingOperands,
241  ValueRange values)
242  : isStorageDynamic(false), operandStorage(trailingOperands) {
243  numOperands = capacity = values.size();
244  for (unsigned i = 0; i < numOperands; ++i)
245  new (&operandStorage[i]) OpOperand(owner, values[i]);
246 }
247 
249  for (auto &operand : getOperands())
250  operand.~OpOperand();
251 
252  // If the storage is dynamic, deallocate it.
253  if (isStorageDynamic)
254  free(operandStorage);
255 }
256 
257 /// Replace the operands contained in the storage with the ones provided in
258 /// 'values'.
260  MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
261  for (unsigned i = 0, e = values.size(); i != e; ++i)
262  storageOperands[i].set(values[i]);
263 }
264 
265 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
266 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
267 /// than the range pointed to by 'start'+'length'.
268 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
269  unsigned length, ValueRange operands) {
270  // If the new size is the same, we can update inplace.
271  unsigned newSize = operands.size();
272  if (newSize == length) {
273  MutableArrayRef<OpOperand> storageOperands = getOperands();
274  for (unsigned i = 0, e = length; i != e; ++i)
275  storageOperands[start + i].set(operands[i]);
276  return;
277  }
278  // If the new size is greater, remove the extra operands and set the rest
279  // inplace.
280  if (newSize < length) {
281  eraseOperands(start + operands.size(), length - newSize);
282  setOperands(owner, start, newSize, operands);
283  return;
284  }
285  // Otherwise, the new size is greater so we need to grow the storage.
286  auto storageOperands = resize(owner, size() + (newSize - length));
287 
288  // Shift operands to the right to make space for the new operands.
289  unsigned rotateSize = storageOperands.size() - (start + length);
290  auto rbegin = storageOperands.rbegin();
291  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
292 
293  // Update the operands inplace.
294  for (unsigned i = 0, e = operands.size(); i != e; ++i)
295  storageOperands[start + i].set(operands[i]);
296 }
297 
298 /// Erase an operand held by the storage.
299 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
300  MutableArrayRef<OpOperand> operands = getOperands();
301  assert((start + length) <= operands.size());
302  numOperands -= length;
303 
304  // Shift all operands down if the operand to remove is not at the end.
305  if (start != numOperands) {
306  auto *indexIt = std::next(operands.begin(), start);
307  std::rotate(indexIt, std::next(indexIt, length), operands.end());
308  }
309  for (unsigned i = 0; i != length; ++i)
310  operands[numOperands + i].~OpOperand();
311 }
312 
313 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
314  MutableArrayRef<OpOperand> operands = getOperands();
315  assert(eraseIndices.size() == operands.size());
316 
317  // Check that at least one operand is erased.
318  int firstErasedIndice = eraseIndices.find_first();
319  if (firstErasedIndice == -1)
320  return;
321 
322  // Shift all of the removed operands to the end, and destroy them.
323  numOperands = firstErasedIndice;
324  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
325  if (!eraseIndices.test(i))
326  operands[numOperands++] = std::move(operands[i]);
327  for (OpOperand &operand : operands.drop_front(numOperands))
328  operand.~OpOperand();
329 }
330 
331 /// Resize the storage to the given size. Returns the array containing the new
332 /// operands.
333 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
334  unsigned newSize) {
335  // If the number of operands is less than or equal to the current amount, we
336  // can just update in place.
337  MutableArrayRef<OpOperand> origOperands = getOperands();
338  if (newSize <= numOperands) {
339  // If the number of new size is less than the current, remove any extra
340  // operands.
341  for (unsigned i = newSize; i != numOperands; ++i)
342  origOperands[i].~OpOperand();
343  numOperands = newSize;
344  return origOperands.take_front(newSize);
345  }
346 
347  // If the new size is within the original inline capacity, grow inplace.
348  if (newSize <= capacity) {
349  OpOperand *opBegin = origOperands.data();
350  for (unsigned e = newSize; numOperands != e; ++numOperands)
351  new (&opBegin[numOperands]) OpOperand(owner);
352  return MutableArrayRef<OpOperand>(opBegin, newSize);
353  }
354 
355  // Otherwise, we need to allocate a new storage.
356  unsigned newCapacity =
357  std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
358  OpOperand *newOperandStorage =
359  reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
360 
361  // Move the current operands to the new storage.
362  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
363  std::uninitialized_move(origOperands.begin(), origOperands.end(),
364  newOperands.begin());
365 
366  // Destroy the original operands.
367  for (auto &operand : origOperands)
368  operand.~OpOperand();
369 
370  // Initialize any new operands.
371  for (unsigned e = newSize; numOperands != e; ++numOperands)
372  new (&newOperands[numOperands]) OpOperand(owner);
373 
374  // If the current storage is dynamic, free it.
375  if (isStorageDynamic)
376  free(operandStorage);
377 
378  // Update the storage representation to use the new dynamic storage.
379  operandStorage = newOperandStorage;
380  capacity = newCapacity;
381  isStorageDynamic = true;
382  return newOperands;
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // Operation Value-Iterators
387 //===----------------------------------------------------------------------===//
388 
389 //===----------------------------------------------------------------------===//
390 // OperandRange
391 
393  assert(!empty() && "range must not be empty");
394  return base->getOperandNumber();
395 }
396 
398  return OperandRangeRange(*this, segmentSizes);
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // OperandRangeRange
403 
405  Attribute operandSegments)
406  : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
407  llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
408 }
409 
411  const OwnerT &owner = getBase();
412  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
413  return OperandRange(owner.first,
414  std::accumulate(sizeData.begin(), sizeData.end(), 0));
415 }
416 
417 OperandRange OperandRangeRange::dereference(const OwnerT &object,
418  ptrdiff_t index) {
419  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
420  uint32_t startIndex =
421  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
422  return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // MutableOperandRange
427 
428 /// Construct a new mutable range from the given operand, operand start index,
429 /// and range length.
431  Operation *owner, unsigned start, unsigned length,
432  ArrayRef<OperandSegment> operandSegments)
433  : owner(owner), start(start), length(length),
434  operandSegments(operandSegments.begin(), operandSegments.end()) {
435  assert((start + length) <= owner->getNumOperands() && "invalid range");
436 }
438  : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
439 
440 /// Construct a new mutable range for the given OpOperand.
442  : MutableOperandRange(opOperand.getOwner(),
443  /*start=*/opOperand.getOperandNumber(),
444  /*length=*/1) {}
445 
446 /// Slice this range into a sub range, with the additional operand segment.
448 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
449  std::optional<OperandSegment> segment) const {
450  assert((subStart + subLen) <= length && "invalid sub-range");
451  MutableOperandRange subSlice(owner, start + subStart, subLen,
452  operandSegments);
453  if (segment)
454  subSlice.operandSegments.push_back(*segment);
455  return subSlice;
456 }
457 
458 /// Append the given values to the range.
460  if (values.empty())
461  return;
462  owner->insertOperands(start + length, values);
463  updateLength(length + values.size());
464 }
465 
466 /// Assign this range to the given values.
468  owner->setOperands(start, length, values);
469  if (length != values.size())
470  updateLength(/*newLength=*/values.size());
471 }
472 
473 /// Assign the range to the given value.
475  if (length == 1) {
476  owner->setOperand(start, value);
477  } else {
478  owner->setOperands(start, length, value);
479  updateLength(/*newLength=*/1);
480  }
481 }
482 
483 /// Erase the operands within the given sub-range.
484 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
485  assert((subStart + subLen) <= length && "invalid sub-range");
486  if (length == 0)
487  return;
488  owner->eraseOperands(start + subStart, subLen);
489  updateLength(length - subLen);
490 }
491 
492 /// Clear this range and erase all of the operands.
494  if (length != 0) {
495  owner->eraseOperands(start, length);
496  updateLength(/*newLength=*/0);
497  }
498 }
499 
500 /// Allow implicit conversion to an OperandRange.
501 MutableOperandRange::operator OperandRange() const {
502  return owner->getOperands().slice(start, length);
503 }
504 
505 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
506  return owner->getOpOperands().slice(start, length);
507 }
508 
511  return MutableOperandRangeRange(*this, segmentSizes);
512 }
513 
514 /// Update the length of this range to the one provided.
515 void MutableOperandRange::updateLength(unsigned newLength) {
516  int32_t diff = int32_t(newLength) - int32_t(length);
517  length = newLength;
518 
519  // Update any of the provided segment attributes.
520  for (OperandSegment &segment : operandSegments) {
521  auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
522  SmallVector<int32_t, 8> segments(attr.asArrayRef());
523  segments[segment.first] += diff;
524  segment.second.setValue(
525  DenseI32ArrayAttr::get(attr.getContext(), segments));
526  owner->setAttr(segment.second.getName(), segment.second.getValue());
527  }
528 }
529 
531  assert(index < length && "index is out of bounds");
532  return owner->getOpOperand(start + index);
533 }
534 
536  return owner->getOpOperands().slice(start, length).begin();
537 }
538 
540  return owner->getOpOperands().slice(start, length).end();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // MutableOperandRangeRange
545 
547  const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
549  OwnerT(operands, operandSegmentAttr), 0,
550  llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
551 }
552 
554  return getBase().first;
555 }
556 
557 MutableOperandRangeRange::operator OperandRangeRange() const {
558  return OperandRangeRange(getBase().first, getBase().second.getValue());
559 }
560 
561 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
562  ptrdiff_t index) {
563  ArrayRef<int32_t> sizeData =
564  llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
565  uint32_t startIndex =
566  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
567  return object.first.slice(
568  startIndex, *(sizeData.begin() + index),
569  MutableOperandRange::OperandSegment(index, object.second));
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // ResultRange
574 
576  : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
577  1) {}
578 
580  return {use_begin(), use_end()};
581 }
583  return use_iterator(*this);
584 }
586  return use_iterator(*this, /*end=*/true);
587 }
589  return {user_begin(), user_end()};
590 }
592  return user_iterator(use_begin());
593 }
595  return user_iterator(use_end());
596 }
597 
599  : it(end ? results.end() : results.begin()), endIt(results.end()) {
600  // Only initialize current use if there are results/can be uses.
601  if (it != endIt)
602  skipOverResultsWithNoUsers();
603 }
604 
606  // We increment over uses, if we reach the last use then move to next
607  // result.
608  if (use != (*it).use_end())
609  ++use;
610  if (use == (*it).use_end()) {
611  ++it;
612  skipOverResultsWithNoUsers();
613  }
614  return *this;
615 }
616 
617 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
618  while (it != endIt && (*it).use_empty())
619  ++it;
620 
621  // If we are at the last result, then set use to first use of
622  // first result (sentinel value used for end).
623  if (it == endIt)
624  use = {};
625  else
626  use = (*it).use_begin();
627 }
628 
631 }
632 
634  Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
635  replaceUsesWithIf(op->getResults(), shouldReplace);
636 }
637 
638 //===----------------------------------------------------------------------===//
639 // ValueRange
640 
642  : ValueRange(values.data(), values.size()) {}
644  : ValueRange(values.begin().getBase(), values.size()) {}
646  : ValueRange(values.getBase(), values.size()) {}
647 
648 /// See `llvm::detail::indexed_accessor_range_base` for details.
649 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
650  ptrdiff_t index) {
651  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
652  return {value + index};
653  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
654  return {operand + index};
655  return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
656 }
657 /// See `llvm::detail::indexed_accessor_range_base` for details.
658 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
659  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
660  return value[index];
661  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
662  return operand[index].get();
663  return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // Operation Equivalency
668 //===----------------------------------------------------------------------===//
669 
671  Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
672  function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
673  // Hash operations based upon their:
674  // - Operation Name
675  // - Attributes
676  // - Result Types
677  llvm::hash_code hash =
678  llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
679  op->getResultTypes(), op->hashProperties());
680 
681  // - Location if required
682  if (!(flags & Flags::IgnoreLocations))
683  hash = llvm::hash_combine(hash, op->getLoc());
684 
685  // - Operands
686  for (Value operand : op->getOperands())
687  hash = llvm::hash_combine(hash, hashOperands(operand));
688 
689  // - Results
690  for (Value result : op->getResults())
691  hash = llvm::hash_combine(hash, hashResults(result));
692  return hash;
693 }
694 
696  Region *lhs, Region *rhs,
697  function_ref<LogicalResult(Value, Value)> checkEquivalent,
698  function_ref<void(Value, Value)> markEquivalent,
700  DenseMap<Block *, Block *> blocksMap;
701  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
702  // Check block arguments.
703  if (lBlock.getNumArguments() != rBlock.getNumArguments())
704  return false;
705 
706  // Map the two blocks.
707  auto insertion = blocksMap.insert({&lBlock, &rBlock});
708  if (insertion.first->getSecond() != &rBlock)
709  return false;
710 
711  for (auto argPair :
712  llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
713  Value curArg = std::get<0>(argPair);
714  Value otherArg = std::get<1>(argPair);
715  if (curArg.getType() != otherArg.getType())
716  return false;
717  if (!(flags & OperationEquivalence::IgnoreLocations) &&
718  curArg.getLoc() != otherArg.getLoc())
719  return false;
720  // Corresponding bbArgs are equivalent.
721  if (markEquivalent)
722  markEquivalent(curArg, otherArg);
723  }
724 
725  auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
726  // Check for op equality (recursively).
727  if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
728  markEquivalent, flags))
729  return false;
730  // Check successor mapping.
731  for (auto successorsPair :
732  llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
733  Block *curSuccessor = std::get<0>(successorsPair);
734  Block *otherSuccessor = std::get<1>(successorsPair);
735  auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
736  if (insertion.first->getSecond() != otherSuccessor)
737  return false;
738  }
739  return true;
740  };
741  return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
742  };
743  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
744 }
745 
746 // Value equivalence cache to be used with `isRegionEquivalentTo` and
747 // `isEquivalentTo`.
750  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
751  return success(lhsValue == rhsValue ||
752  equivalentValues.lookup(lhsValue) == rhsValue);
753  }
754  void markEquivalent(Value lhsResult, Value rhsResult) {
755  auto insertion = equivalentValues.insert({lhsResult, rhsResult});
756  // Make sure that the value was not already marked equivalent to some other
757  // value.
758  (void)insertion;
759  assert(insertion.first->second == rhsResult &&
760  "inconsistent OperationEquivalence state");
761  }
762 };
763 
764 /*static*/ bool
767  ValueEquivalenceCache cache;
768  return isRegionEquivalentTo(
769  lhs, rhs,
770  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
771  return cache.checkEquivalent(lhsValue, rhsValue);
772  },
773  [&](Value lhsResult, Value rhsResult) {
774  cache.markEquivalent(lhsResult, rhsResult);
775  },
776  flags);
777 }
778 
780  Operation *lhs, Operation *rhs,
781  function_ref<LogicalResult(Value, Value)> checkEquivalent,
782  function_ref<void(Value, Value)> markEquivalent, Flags flags) {
783  if (lhs == rhs)
784  return true;
785 
786  // 1. Compare the operation properties.
787  if (lhs->getName() != rhs->getName() ||
790  lhs->getNumRegions() != rhs->getNumRegions() ||
791  lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
792  lhs->getNumOperands() != rhs->getNumOperands() ||
793  lhs->getNumResults() != rhs->getNumResults() ||
795  rhs->getPropertiesStorage()))
796  return false;
797  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
798  return false;
799 
800  // 2. Compare operands.
801  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
802  Value curArg = std::get<0>(operandPair);
803  Value otherArg = std::get<1>(operandPair);
804  if (curArg == otherArg)
805  continue;
806  if (curArg.getType() != otherArg.getType())
807  return false;
808  if (failed(checkEquivalent(curArg, otherArg)))
809  return false;
810  }
811 
812  // 3. Compare result types and mark results as equivalent.
813  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
814  Value curArg = std::get<0>(resultPair);
815  Value otherArg = std::get<1>(resultPair);
816  if (curArg.getType() != otherArg.getType())
817  return false;
818  if (markEquivalent)
819  markEquivalent(curArg, otherArg);
820  }
821 
822  // 4. Compare regions.
823  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
824  if (!isRegionEquivalentTo(&std::get<0>(regionPair),
825  &std::get<1>(regionPair), checkEquivalent,
826  markEquivalent, flags))
827  return false;
828 
829  return true;
830 }
831 
833  Operation *rhs,
834  Flags flags) {
835  ValueEquivalenceCache cache;
837  lhs, rhs,
838  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
839  return cache.checkEquivalent(lhsValue, rhsValue);
840  },
841  [&](Value lhsResult, Value rhsResult) {
842  cache.markEquivalent(lhsResult, rhsResult);
843  },
844  flags);
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // OperationFingerPrint
849 //===----------------------------------------------------------------------===//
850 
851 template <typename T>
852 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
853  hasher.update(
854  ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
855 }
856 
858  llvm::SHA1 hasher;
859 
860  // Hash each of the operations based upon their mutable bits:
861  topOp->walk([&](Operation *op) {
862  // - Operation pointer
863  addDataToHash(hasher, op);
864  // - Parent operation pointer (to take into account the nesting structure)
865  if (op != topOp)
866  addDataToHash(hasher, op->getParentOp());
867  // - Attributes
869  // - Properties
870  addDataToHash(hasher, op->hashProperties());
871  // - Blocks in Regions
872  for (Region &region : op->getRegions()) {
873  for (Block &block : region) {
874  addDataToHash(hasher, &block);
875  for (BlockArgument arg : block.getArguments())
876  addDataToHash(hasher, arg);
877  }
878  }
879  // - Location
880  addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
881  // - Operands
882  for (Value operand : op->getOperands())
883  addDataToHash(hasher, operand);
884  // - Successors
885  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
886  addDataToHash(hasher, op->getSuccessor(i));
887  // - Result types
888  for (Type t : op->getResultTypes())
889  addDataToHash(hasher, t);
890  });
891  hash = hasher.result();
892 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:315
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumArguments()
Definition: Block.h:121
BlockArgListType getArguments()
Definition: Block.h:80
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Location.h:104
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a contiguous range of mutable operand ranges, e.g.
Definition: ValueRange.h:203
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments=std::nullopt)
Construct a new mutable range from the given operand, operand start index, and range length.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition: ValueRange.h:120
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists, else returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
This class represents an operand of an operation.
Definition: Value.h:263
This is a value defined by a result of an operation.
Definition: Value.h:453
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:82
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp)
bool compareOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:255
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:383
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
Block * getSuccessor(unsigned index)
Definition: Operation.h:687
unsigned getNumSuccessors()
Definition: Operation.h:685
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition: Operation.h:355
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:560
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition: Operation.h:483
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
Definition: Operation.cpp:354
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
SuccessorRange getSuccessors()
Definition: Operation.h:682
result_range getResults()
Definition: Operation.h:410
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
Definition: Operation.cpp:369
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:879
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements a use iterator for a range of operation results.
Definition: ValueRange.h:342
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:239
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition: ValueRange.h:314
ResultRange(OpResult result)
use_iterator use_end() const
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition: ValueRange.h:295
user_iterator user_end()
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:278
user_iterator user_begin()
UseIterator use_iterator
Definition: ValueRange.h:259
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition: ValueRange.h:383
ValueRange(Arg &&arg)
Definition: ValueRange.h:391
An iterator over the users of an IRObject.
Definition: UseDefLists.h:344
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
This class provides the implementation for an operation result.
Definition: Value.h:364
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
Include the generated interface declarations.
Definition: CallGraph.h:229
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
DenseMap< Value, Value > equivalentValues
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags)
Compare two regions (including their subregions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None)
Compare two operations (including their regions) and return if they are equivalent.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addRegions(MutableArrayRef< std::unique_ptr< Region >> regions)
Take ownership of a set of regions that should be attached to the Operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const