MLIR  21.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1 //===- OperationSupport.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains out-of-line implementations of the support types that
10 // Operation and related classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/BitVector.h"
19 #include "llvm/Support/SHA1.h"
20 #include <numeric>
21 #include <optional>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // NamedAttrList
27 //===----------------------------------------------------------------------===//
28 
30  assign(attributes.begin(), attributes.end());
31 }
32 
33 NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34  : NamedAttrList(attributes ? attributes.getValue()
35  : ArrayRef<NamedAttribute>()) {
36  dictionarySorted.setPointerAndInt(attributes, true);
37 }
38 
40  assign(inStart, inEnd);
41 }
42 
44 
45 std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46  std::optional<NamedAttribute> duplicate =
47  DictionaryAttr::findDuplicate(attrs, isSorted());
48  // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49  // state.
50  if (!isSorted())
51  dictionarySorted.setPointerAndInt(nullptr, true);
52  return duplicate;
53 }
54 
55 DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56  if (!isSorted()) {
57  DictionaryAttr::sortInPlace(attrs);
58  dictionarySorted.setPointerAndInt(nullptr, true);
59  }
60  if (!dictionarySorted.getPointer())
61  dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62  return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
63 }
64 
65 /// Replaces the attributes with new list of attributes.
67  DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
68  dictionarySorted.setPointerAndInt(nullptr, true);
69 }
70 
72  if (isSorted())
73  dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
74  dictionarySorted.setPointer(nullptr);
75  attrs.push_back(newAttribute);
76 }
77 
78 /// Return the specified attribute if present, null otherwise.
79 Attribute NamedAttrList::get(StringRef name) const {
80  auto it = findAttr(*this, name);
81  return it.second ? it.first->getValue() : Attribute();
82 }
83 Attribute NamedAttrList::get(StringAttr name) const {
84  auto it = findAttr(*this, name);
85  return it.second ? it.first->getValue() : Attribute();
86 }
87 
88 /// Return the specified named attribute if present, std::nullopt otherwise.
89 std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
90  auto it = findAttr(*this, name);
91  return it.second ? *it.first : std::optional<NamedAttribute>();
92 }
93 std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
94  auto it = findAttr(*this, name);
95  return it.second ? *it.first : std::optional<NamedAttribute>();
96 }
97 
98 /// If the an attribute exists with the specified name, change it to the new
99 /// value. Otherwise, add a new attribute with the specified name/value.
100 Attribute NamedAttrList::set(StringAttr name, Attribute value) {
101  assert(value && "attributes may never be null");
102 
103  // Look for an existing attribute with the given name, and set its value
104  // in-place. Return the previous value of the attribute, if there was one.
105  auto it = findAttr(*this, name);
106  if (it.second) {
107  // Update the existing attribute by swapping out the old value for the new
108  // value. Return the old value.
109  Attribute oldValue = it.first->getValue();
110  if (it.first->getValue() != value) {
111  it.first->setValue(value);
112 
113  // If the attributes have changed, the dictionary is invalidated.
114  dictionarySorted.setPointer(nullptr);
115  }
116  return oldValue;
117  }
118  // Perform a string lookup to insert the new attribute into its sorted
119  // position.
120  if (isSorted())
121  it = findAttr(*this, name.strref());
122  attrs.insert(it.first, {name, value});
123  // Invalidate the dictionary. Return null as there was no previous value.
124  dictionarySorted.setPointer(nullptr);
125  return Attribute();
126 }
127 
128 Attribute NamedAttrList::set(StringRef name, Attribute value) {
129  assert(value && "attributes may never be null");
130  return set(mlir::StringAttr::get(value.getContext(), name), value);
131 }
132 
133 Attribute
134 NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
135  // Erasing does not affect the sorted property.
136  Attribute attr = it->getValue();
137  attrs.erase(it);
138  dictionarySorted.setPointer(nullptr);
139  return attr;
140 }
141 
142 Attribute NamedAttrList::erase(StringAttr name) {
143  auto it = findAttr(*this, name);
144  return it.second ? eraseImpl(it.first) : Attribute();
145 }
146 
148  auto it = findAttr(*this, name);
149  return it.second ? eraseImpl(it.first) : Attribute();
150 }
151 
154  assign(rhs.begin(), rhs.end());
155  return *this;
156 }
157 
158 NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
159 
160 //===----------------------------------------------------------------------===//
161 // OperationState
162 //===----------------------------------------------------------------------===//
163 
164 OperationState::OperationState(Location location, StringRef name)
165  : location(location), name(name, location->getContext()) {}
166 
168  : location(location), name(name) {}
169 
171  ValueRange operands, TypeRange types,
172  ArrayRef<NamedAttribute> attributes,
173  BlockRange successors,
174  MutableArrayRef<std::unique_ptr<Region>> regions)
175  : location(location), name(name),
176  operands(operands.begin(), operands.end()),
177  types(types.begin(), types.end()),
178  attributes(attributes.begin(), attributes.end()),
179  successors(successors.begin(), successors.end()) {
180  for (std::unique_ptr<Region> &r : regions)
181  this->regions.push_back(std::move(r));
182 }
183 OperationState::OperationState(Location location, StringRef name,
184  ValueRange operands, TypeRange types,
185  ArrayRef<NamedAttribute> attributes,
186  BlockRange successors,
187  MutableArrayRef<std::unique_ptr<Region>> regions)
188  : OperationState(location, OperationName(name, location.getContext()),
189  operands, types, attributes, successors, regions) {}
190 
192  if (properties)
193  propertiesDeleter(properties);
194 }
195 
198  if (LLVM_UNLIKELY(propertiesAttr)) {
199  assert(!properties);
201  }
202  if (properties)
203  propertiesSetter(op->getPropertiesStorage(), properties);
204  return success();
205 }
206 
208  operands.append(newOperands.begin(), newOperands.end());
209 }
210 
212  successors.append(newSuccessors.begin(), newSuccessors.end());
213 }
214 
216  regions.emplace_back(new Region);
217  return regions.back().get();
218 }
219 
220 void OperationState::addRegion(std::unique_ptr<Region> &&region) {
221  regions.push_back(std::move(region));
222 }
223 
225  MutableArrayRef<std::unique_ptr<Region>> regions) {
226  for (std::unique_ptr<Region> &region : regions)
227  addRegion(std::move(region));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // OperandStorage
232 //===----------------------------------------------------------------------===//
233 
235  OpOperand *trailingOperands,
236  ValueRange values)
237  : isStorageDynamic(false), operandStorage(trailingOperands) {
238  numOperands = capacity = values.size();
239  for (unsigned i = 0; i < numOperands; ++i)
240  new (&operandStorage[i]) OpOperand(owner, values[i]);
241 }
242 
244  for (auto &operand : getOperands())
245  operand.~OpOperand();
246 
247  // If the storage is dynamic, deallocate it.
248  if (isStorageDynamic)
249  free(operandStorage);
250 }
251 
252 /// Replace the operands contained in the storage with the ones provided in
253 /// 'values'.
255  MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
256  for (unsigned i = 0, e = values.size(); i != e; ++i)
257  storageOperands[i].set(values[i]);
258 }
259 
260 /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
261 /// with the ones provided in 'operands'. 'operands' may be smaller or larger
262 /// than the range pointed to by 'start'+'length'.
263 void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
264  unsigned length, ValueRange operands) {
265  // If the new size is the same, we can update inplace.
266  unsigned newSize = operands.size();
267  if (newSize == length) {
268  MutableArrayRef<OpOperand> storageOperands = getOperands();
269  for (unsigned i = 0, e = length; i != e; ++i)
270  storageOperands[start + i].set(operands[i]);
271  return;
272  }
273  // If the new size is greater, remove the extra operands and set the rest
274  // inplace.
275  if (newSize < length) {
276  eraseOperands(start + operands.size(), length - newSize);
277  setOperands(owner, start, newSize, operands);
278  return;
279  }
280  // Otherwise, the new size is greater so we need to grow the storage.
281  auto storageOperands = resize(owner, size() + (newSize - length));
282 
283  // Shift operands to the right to make space for the new operands.
284  unsigned rotateSize = storageOperands.size() - (start + length);
285  auto rbegin = storageOperands.rbegin();
286  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
287 
288  // Update the operands inplace.
289  for (unsigned i = 0, e = operands.size(); i != e; ++i)
290  storageOperands[start + i].set(operands[i]);
291 }
292 
293 /// Erase an operand held by the storage.
294 void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
295  MutableArrayRef<OpOperand> operands = getOperands();
296  assert((start + length) <= operands.size());
297  numOperands -= length;
298 
299  // Shift all operands down if the operand to remove is not at the end.
300  if (start != numOperands) {
301  auto *indexIt = std::next(operands.begin(), start);
302  std::rotate(indexIt, std::next(indexIt, length), operands.end());
303  }
304  for (unsigned i = 0; i != length; ++i)
305  operands[numOperands + i].~OpOperand();
306 }
307 
308 void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
309  MutableArrayRef<OpOperand> operands = getOperands();
310  assert(eraseIndices.size() == operands.size());
311 
312  // Check that at least one operand is erased.
313  int firstErasedIndice = eraseIndices.find_first();
314  if (firstErasedIndice == -1)
315  return;
316 
317  // Shift all of the removed operands to the end, and destroy them.
318  numOperands = firstErasedIndice;
319  for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
320  if (!eraseIndices.test(i))
321  operands[numOperands++] = std::move(operands[i]);
322  for (OpOperand &operand : operands.drop_front(numOperands))
323  operand.~OpOperand();
324 }
325 
326 /// Resize the storage to the given size. Returns the array containing the new
327 /// operands.
328 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
329  unsigned newSize) {
330  // If the number of operands is less than or equal to the current amount, we
331  // can just update in place.
332  MutableArrayRef<OpOperand> origOperands = getOperands();
333  if (newSize <= numOperands) {
334  // If the number of new size is less than the current, remove any extra
335  // operands.
336  for (unsigned i = newSize; i != numOperands; ++i)
337  origOperands[i].~OpOperand();
338  numOperands = newSize;
339  return origOperands.take_front(newSize);
340  }
341 
342  // If the new size is within the original inline capacity, grow inplace.
343  if (newSize <= capacity) {
344  OpOperand *opBegin = origOperands.data();
345  for (unsigned e = newSize; numOperands != e; ++numOperands)
346  new (&opBegin[numOperands]) OpOperand(owner);
347  return MutableArrayRef<OpOperand>(opBegin, newSize);
348  }
349 
350  // Otherwise, we need to allocate a new storage.
351  unsigned newCapacity =
352  std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
353  OpOperand *newOperandStorage =
354  reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
355 
356  // Move the current operands to the new storage.
357  MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
358  std::uninitialized_move(origOperands.begin(), origOperands.end(),
359  newOperands.begin());
360 
361  // Destroy the original operands.
362  for (auto &operand : origOperands)
363  operand.~OpOperand();
364 
365  // Initialize any new operands.
366  for (unsigned e = newSize; numOperands != e; ++numOperands)
367  new (&newOperands[numOperands]) OpOperand(owner);
368 
369  // If the current storage is dynamic, free it.
370  if (isStorageDynamic)
371  free(operandStorage);
372 
373  // Update the storage representation to use the new dynamic storage.
374  operandStorage = newOperandStorage;
375  capacity = newCapacity;
376  isStorageDynamic = true;
377  return newOperands;
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // Operation Value-Iterators
382 //===----------------------------------------------------------------------===//
383 
384 //===----------------------------------------------------------------------===//
385 // OperandRange
386 
388  assert(!empty() && "range must not be empty");
389  return base->getOperandNumber();
390 }
391 
393  return OperandRangeRange(*this, segmentSizes);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // OperandRangeRange
398 
400  Attribute operandSegments)
401  : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
402  llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
403 }
404 
406  const OwnerT &owner = getBase();
407  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
408  return OperandRange(owner.first,
409  std::accumulate(sizeData.begin(), sizeData.end(), 0));
410 }
411 
412 OperandRange OperandRangeRange::dereference(const OwnerT &object,
413  ptrdiff_t index) {
414  ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
415  uint32_t startIndex =
416  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
417  return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // MutableOperandRange
422 
423 /// Construct a new mutable range from the given operand, operand start index,
424 /// and range length.
426  Operation *owner, unsigned start, unsigned length,
427  ArrayRef<OperandSegment> operandSegments)
428  : owner(owner), start(start), length(length),
429  operandSegments(operandSegments) {
430  assert((start + length) <= owner->getNumOperands() && "invalid range");
431 }
433  : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
434 
435 /// Construct a new mutable range for the given OpOperand.
437  : MutableOperandRange(opOperand.getOwner(),
438  /*start=*/opOperand.getOperandNumber(),
439  /*length=*/1) {}
440 
441 /// Slice this range into a sub range, with the additional operand segment.
443 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
444  std::optional<OperandSegment> segment) const {
445  assert((subStart + subLen) <= length && "invalid sub-range");
446  MutableOperandRange subSlice(owner, start + subStart, subLen,
447  operandSegments);
448  if (segment)
449  subSlice.operandSegments.push_back(*segment);
450  return subSlice;
451 }
452 
453 /// Append the given values to the range.
455  if (values.empty())
456  return;
457  owner->insertOperands(start + length, values);
458  updateLength(length + values.size());
459 }
460 
461 /// Assign this range to the given values.
463  owner->setOperands(start, length, values);
464  if (length != values.size())
465  updateLength(/*newLength=*/values.size());
466 }
467 
468 /// Assign the range to the given value.
470  if (length == 1) {
471  owner->setOperand(start, value);
472  } else {
473  owner->setOperands(start, length, value);
474  updateLength(/*newLength=*/1);
475  }
476 }
477 
478 /// Erase the operands within the given sub-range.
479 void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
480  assert((subStart + subLen) <= length && "invalid sub-range");
481  if (length == 0)
482  return;
483  owner->eraseOperands(start + subStart, subLen);
484  updateLength(length - subLen);
485 }
486 
487 /// Clear this range and erase all of the operands.
489  if (length != 0) {
490  owner->eraseOperands(start, length);
491  updateLength(/*newLength=*/0);
492  }
493 }
494 
495 /// Explicit conversion to an OperandRange.
497  return owner->getOperands().slice(start, length);
498 }
499 
500 /// Allow implicit conversion to an OperandRange.
501 MutableOperandRange::operator OperandRange() const {
502  return getAsOperandRange();
503 }
504 
505 MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
506  return owner->getOpOperands().slice(start, length);
507 }
508 
511  return MutableOperandRangeRange(*this, segmentSizes);
512 }
513 
514 /// Update the length of this range to the one provided.
515 void MutableOperandRange::updateLength(unsigned newLength) {
516  int32_t diff = int32_t(newLength) - int32_t(length);
517  length = newLength;
518 
519  // Update any of the provided segment attributes.
520  for (OperandSegment &segment : operandSegments) {
521  auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
522  SmallVector<int32_t, 8> segments(attr.asArrayRef());
523  segments[segment.first] += diff;
524  segment.second.setValue(
525  DenseI32ArrayAttr::get(attr.getContext(), segments));
526  owner->setAttr(segment.second.getName(), segment.second.getValue());
527  }
528 }
529 
531  assert(index < length && "index is out of bounds");
532  return owner->getOpOperand(start + index);
533 }
534 
536  return owner->getOpOperands().slice(start, length).begin();
537 }
538 
540  return owner->getOpOperands().slice(start, length).end();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // MutableOperandRangeRange
545 
547  const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
549  OwnerT(operands, operandSegmentAttr), 0,
550  llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
551 }
552 
554  return getBase().first;
555 }
556 
557 MutableOperandRangeRange::operator OperandRangeRange() const {
558  return OperandRangeRange(getBase().first, getBase().second.getValue());
559 }
560 
561 MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
562  ptrdiff_t index) {
563  ArrayRef<int32_t> sizeData =
564  llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
565  uint32_t startIndex =
566  std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
567  return object.first.slice(
568  startIndex, *(sizeData.begin() + index),
569  MutableOperandRange::OperandSegment(index, object.second));
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // ResultRange
574 
576  : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
577  1) {}
578 
580  return {use_begin(), use_end()};
581 }
583  return use_iterator(*this);
584 }
586  return use_iterator(*this, /*end=*/true);
587 }
589  return {user_begin(), user_end()};
590 }
592  return user_iterator(use_begin());
593 }
595  return user_iterator(use_end());
596 }
597 
599  : it(end ? results.end() : results.begin()), endIt(results.end()) {
600  // Only initialize current use if there are results/can be uses.
601  if (it != endIt)
602  skipOverResultsWithNoUsers();
603 }
604 
606  // We increment over uses, if we reach the last use then move to next
607  // result.
608  if (use != (*it).use_end())
609  ++use;
610  if (use == (*it).use_end()) {
611  ++it;
612  skipOverResultsWithNoUsers();
613  }
614  return *this;
615 }
616 
617 void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
618  while (it != endIt && (*it).use_empty())
619  ++it;
620 
621  // If we are at the last result, then set use to first use of
622  // first result (sentinel value used for end).
623  if (it == endIt)
624  use = {};
625  else
626  use = (*it).use_begin();
627 }
628 
631 }
632 
634  Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
635  replaceUsesWithIf(op->getResults(), shouldReplace);
636 }
637 
638 //===----------------------------------------------------------------------===//
639 // ValueRange
640 
642  : ValueRange(values.data(), values.size()) {}
644  : ValueRange(values.begin().getBase(), values.size()) {}
646  : ValueRange(values.getBase(), values.size()) {}
647 
648 /// See `llvm::detail::indexed_accessor_range_base` for details.
649 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
650  ptrdiff_t index) {
651  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
652  return {value + index};
653  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
654  return {operand + index};
655  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
656 }
657 /// See `llvm::detail::indexed_accessor_range_base` for details.
658 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
659  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
660  return value[index];
661  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
662  return operand[index].get();
663  return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // Operation Equivalency
668 //===----------------------------------------------------------------------===//
669 
671  Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
672  function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
673  // Hash operations based upon their:
674  // - Operation Name
675  // - Attributes
676  // - Result Types
677  llvm::hash_code hash =
678  llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(),
679  op->getResultTypes(), op->hashProperties());
680 
681  // - Location if required
682  if (!(flags & Flags::IgnoreLocations))
683  hash = llvm::hash_combine(hash, op->getLoc());
684 
685  // - Operands
687  op->getNumOperands() > 0) {
688  size_t operandHash = hashOperands(op->getOperand(0));
689  for (auto operand : op->getOperands().drop_front())
690  operandHash += hashOperands(operand);
691  hash = llvm::hash_combine(hash, operandHash);
692  } else {
693  for (Value operand : op->getOperands())
694  hash = llvm::hash_combine(hash, hashOperands(operand));
695  }
696 
697  // - Results
698  for (Value result : op->getResults())
699  hash = llvm::hash_combine(hash, hashResults(result));
700  return hash;
701 }
702 
704  Region *lhs, Region *rhs,
705  function_ref<LogicalResult(Value, Value)> checkEquivalent,
706  function_ref<void(Value, Value)> markEquivalent,
708  function_ref<LogicalResult(ValueRange, ValueRange)>
709  checkCommutativeEquivalent) {
710  DenseMap<Block *, Block *> blocksMap;
711  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
712  // Check block arguments.
713  if (lBlock.getNumArguments() != rBlock.getNumArguments())
714  return false;
715 
716  // Map the two blocks.
717  auto insertion = blocksMap.insert({&lBlock, &rBlock});
718  if (insertion.first->getSecond() != &rBlock)
719  return false;
720 
721  for (auto argPair :
722  llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
723  Value curArg = std::get<0>(argPair);
724  Value otherArg = std::get<1>(argPair);
725  if (curArg.getType() != otherArg.getType())
726  return false;
727  if (!(flags & OperationEquivalence::IgnoreLocations) &&
728  curArg.getLoc() != otherArg.getLoc())
729  return false;
730  // Corresponding bbArgs are equivalent.
731  if (markEquivalent)
732  markEquivalent(curArg, otherArg);
733  }
734 
735  auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
736  // Check for op equality (recursively).
737  if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
738  markEquivalent, flags,
739  checkCommutativeEquivalent))
740  return false;
741  // Check successor mapping.
742  for (auto successorsPair :
743  llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
744  Block *curSuccessor = std::get<0>(successorsPair);
745  Block *otherSuccessor = std::get<1>(successorsPair);
746  auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
747  if (insertion.first->getSecond() != otherSuccessor)
748  return false;
749  }
750  return true;
751  };
752  return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
753  };
754  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
755 }
756 
757 // Value equivalence cache to be used with `isRegionEquivalentTo` and
758 // `isEquivalentTo`.
761  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
762  return success(lhsValue == rhsValue ||
763  equivalentValues.lookup(lhsValue) == rhsValue);
764  }
765  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
766  ValueRange rhsRange) {
767  // Handle simple case where sizes mismatch.
768  if (lhsRange.size() != rhsRange.size())
769  return failure();
770 
771  // Handle where operands in order are equivalent.
772  auto lhsIt = lhsRange.begin();
773  auto rhsIt = rhsRange.begin();
774  for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
775  if (failed(checkEquivalent(*lhsIt, *rhsIt)))
776  break;
777  }
778  if (lhsIt == lhsRange.end())
779  return success();
780 
781  // Handle another simple case where operands are just a permutation.
782  // Note: This is not sufficient, this handles simple cases relatively
783  // cheaply.
784  auto sortValues = [](ValueRange values) {
785  SmallVector<Value> sortedValues = llvm::to_vector(values);
786  llvm::sort(sortedValues, [](Value a, Value b) {
787  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
788  });
789  return sortedValues;
790  };
791  auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
792  auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
793  return success(lhsSorted == rhsSorted);
794  }
795  void markEquivalent(Value lhsResult, Value rhsResult) {
796  auto insertion = equivalentValues.insert({lhsResult, rhsResult});
797  // Make sure that the value was not already marked equivalent to some other
798  // value.
799  (void)insertion;
800  assert(insertion.first->second == rhsResult &&
801  "inconsistent OperationEquivalence state");
802  }
803 };
804 
805 /*static*/ bool
808  ValueEquivalenceCache cache;
809  return isRegionEquivalentTo(
810  lhs, rhs,
811  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
812  return cache.checkEquivalent(lhsValue, rhsValue);
813  },
814  [&](Value lhsResult, Value rhsResult) {
815  cache.markEquivalent(lhsResult, rhsResult);
816  },
817  flags,
818  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
819  return cache.checkCommutativeEquivalent(lhs, rhs);
820  });
821 }
822 
824  Operation *lhs, Operation *rhs,
825  function_ref<LogicalResult(Value, Value)> checkEquivalent,
826  function_ref<void(Value, Value)> markEquivalent, Flags flags,
827  function_ref<LogicalResult(ValueRange, ValueRange)>
828  checkCommutativeEquivalent) {
829  if (lhs == rhs)
830  return true;
831 
832  // 1. Compare the operation properties.
833  if (lhs->getName() != rhs->getName() ||
834  lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() ||
835  lhs->getNumRegions() != rhs->getNumRegions() ||
836  lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
837  lhs->getNumOperands() != rhs->getNumOperands() ||
838  lhs->getNumResults() != rhs->getNumResults() ||
840  rhs->getPropertiesStorage()))
841  return false;
842  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
843  return false;
844 
845  // 2. Compare operands.
846  if (checkCommutativeEquivalent &&
848  auto lhsRange = lhs->getOperands();
849  auto rhsRange = rhs->getOperands();
850  if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
851  return false;
852  } else {
853  // Check pair wise for equivalence.
854  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
855  Value curArg = std::get<0>(operandPair);
856  Value otherArg = std::get<1>(operandPair);
857  if (curArg == otherArg)
858  continue;
859  if (curArg.getType() != otherArg.getType())
860  return false;
861  if (failed(checkEquivalent(curArg, otherArg)))
862  return false;
863  }
864  }
865 
866  // 3. Compare result types and mark results as equivalent.
867  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
868  Value curArg = std::get<0>(resultPair);
869  Value otherArg = std::get<1>(resultPair);
870  if (curArg.getType() != otherArg.getType())
871  return false;
872  if (markEquivalent)
873  markEquivalent(curArg, otherArg);
874  }
875 
876  // 4. Compare regions.
877  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
878  if (!isRegionEquivalentTo(&std::get<0>(regionPair),
879  &std::get<1>(regionPair), checkEquivalent,
880  markEquivalent, flags))
881  return false;
882 
883  return true;
884 }
885 
887  Operation *rhs,
888  Flags flags) {
889  ValueEquivalenceCache cache;
891  lhs, rhs,
892  [&](Value lhsValue, Value rhsValue) -> LogicalResult {
893  return cache.checkEquivalent(lhsValue, rhsValue);
894  },
895  [&](Value lhsResult, Value rhsResult) {
896  cache.markEquivalent(lhsResult, rhsResult);
897  },
898  flags,
899  [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
900  return cache.checkCommutativeEquivalent(lhs, rhs);
901  });
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // OperationFingerPrint
906 //===----------------------------------------------------------------------===//
907 
908 template <typename T>
909 static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
910  hasher.update(
911  ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
912 }
913 
915  bool includeNested) {
916  llvm::SHA1 hasher;
917 
918  // Helper function that hashes an operation based on its mutable bits:
919  auto addOperationToHash = [&](Operation *op) {
920  // - Operation pointer
921  addDataToHash(hasher, op);
922  // - Parent operation pointer (to take into account the nesting structure)
923  if (op != topOp)
924  addDataToHash(hasher, op->getParentOp());
925  // - Attributes
926  addDataToHash(hasher, op->getRawDictionaryAttrs());
927  // - Properties
928  addDataToHash(hasher, op->hashProperties());
929  // - Blocks in Regions
930  for (Region &region : op->getRegions()) {
931  for (Block &block : region) {
932  addDataToHash(hasher, &block);
933  for (BlockArgument arg : block.getArguments())
934  addDataToHash(hasher, arg);
935  }
936  }
937  // - Location
938  addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
939  // - Operands
940  for (Value operand : op->getOperands())
941  addDataToHash(hasher, operand);
942  // - Successors
943  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
944  addDataToHash(hasher, op->getSuccessor(i));
945  // - Result types
946  for (Type t : op->getResultTypes())
947  addDataToHash(hasher, t);
948  };
949 
950  if (includeNested)
951  topOp->walk(addOperationToHash);
952  else
953  addOperationToHash(topOp);
954 
955  hash = hasher.result();
956 }
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static MLIRContext * getContext(OpFoldResult val)
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:319
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumArguments()
Definition: Block.h:128
BlockArgListType getArguments()
Definition: Block.h:87
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Location.h:110
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a contiguous range of mutable operand ranges, e.g.
Definition: ValueRange.h:206
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments=std::nullopt)
Construct a new mutable range from the given operand, operand start index, and range length.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition: ValueRange.h:120
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists; otherwise returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class adds property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:82
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp, bool includeNested=true)
bool compareOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:388
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
Block * getSuccessor(unsigned index)
Definition: Operation.h:709
unsigned getNumSuccessors()
Definition: Operation.h:707
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position 'idx' and ending at position 'idx'+'length' (exclusive).
Definition: Operation.h:360
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition: Operation.h:509
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
Definition: Operation.cpp:355
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
SuccessorRange getSuccessors()
Definition: Operation.h:704
result_range getResults()
Definition: Operation.h:415
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
Definition: Operation.cpp:370
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements a use iterator for a range of operation results.
Definition: ValueRange.h:345
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition: ValueRange.h:317
ResultRange(OpResult result)
use_iterator use_end() const
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition: ValueRange.h:298
user_iterator user_end()
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:281
user_iterator user_begin()
UseIterator use_iterator
Definition: ValueRange.h:262
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition: ValueRange.h:386
ValueRange(Arg &&arg LLVM_LIFETIME_BOUND)
Definition: ValueRange.h:394
An iterator over the users of an IRObject.
Definition: UseDefLists.h:344
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:243
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, ValueRange rhsRange)
DenseMap< Value, Value > equivalentValues
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addRegions(MutableArrayRef< std::unique_ptr< Region >> regions)
Take ownership of a set of regions that should be attached to the Operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
Adds a successor to the operation state. The successor must not be null.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const