MLIR 23.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1//===- OperationSupport.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains out-of-line implementations of the support types that
10// Operation and related classes build on top of.
11//
12//===----------------------------------------------------------------------===//
13
18#include "llvm/Support/SHA1.h"
19#include <numeric>
20#include <optional>
21
22using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// NamedAttrList
26//===----------------------------------------------------------------------===//
27
29 assign(attributes.begin(), attributes.end());
30}
31
32NamedAttrList::NamedAttrList(DictionaryAttr attributes)
33 : NamedAttrList(attributes ? attributes.getValue()
35 dictionarySorted.setPointerAndInt(attributes, true);
36}
37
39 assign(inStart, inEnd);
40}
41
43
44std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
45 std::optional<NamedAttribute> duplicate =
46 DictionaryAttr::findDuplicate(attrs, isSorted());
47 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
48 // state.
49 if (!isSorted())
50 dictionarySorted.setPointerAndInt(nullptr, true);
51 return duplicate;
52}
53
54DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
55 if (!isSorted()) {
56 DictionaryAttr::sortInPlace(attrs);
57 dictionarySorted.setPointerAndInt(nullptr, true);
58 }
59 if (!dictionarySorted.getPointer())
60 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
61 return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
62}
63
64/// Replaces the attributes with new list of attributes.
66 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
67 dictionarySorted.setPointerAndInt(nullptr, true);
68}
69
71 if (isSorted())
72 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
73 dictionarySorted.setPointer(nullptr);
74 attrs.push_back(newAttribute);
75}
76
77/// Return the specified attribute if present, null otherwise.
78Attribute NamedAttrList::get(StringRef name) const {
79 auto it = findAttr(*this, name);
80 return it.second ? it.first->getValue() : Attribute();
81}
82Attribute NamedAttrList::get(StringAttr name) const {
83 auto it = findAttr(*this, name);
84 return it.second ? it.first->getValue() : Attribute();
85}
86
87/// Return the specified named attribute if present, std::nullopt otherwise.
88std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
89 auto it = findAttr(*this, name);
90 return it.second ? *it.first : std::optional<NamedAttribute>();
91}
92std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
93 auto it = findAttr(*this, name);
94 return it.second ? *it.first : std::optional<NamedAttribute>();
95}
96
97/// If the an attribute exists with the specified name, change it to the new
98/// value. Otherwise, add a new attribute with the specified name/value.
99Attribute NamedAttrList::set(StringAttr name, Attribute value) {
100 assert(value && "attributes may never be null");
101
102 // Look for an existing attribute with the given name, and set its value
103 // in-place. Return the previous value of the attribute, if there was one.
104 auto it = findAttr(*this, name);
105 if (it.second) {
106 // Update the existing attribute by swapping out the old value for the new
107 // value. Return the old value.
108 Attribute oldValue = it.first->getValue();
109 if (it.first->getValue() != value) {
110 it.first->setValue(value);
111
112 // If the attributes have changed, the dictionary is invalidated.
113 dictionarySorted.setPointer(nullptr);
114 }
115 return oldValue;
116 }
117 // Perform a string lookup to insert the new attribute into its sorted
118 // position.
119 if (isSorted())
120 it = findAttr(*this, name.strref());
121 attrs.insert(it.first, {name, value});
122 // Invalidate the dictionary. Return null as there was no previous value.
123 dictionarySorted.setPointer(nullptr);
124 return Attribute();
125}
126
127Attribute NamedAttrList::set(StringRef name, Attribute value) {
128 assert(value && "attributes may never be null");
129 return set(mlir::StringAttr::get(value.getContext(), name), value);
130}
131
133NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
134 // Erasing does not affect the sorted property.
135 Attribute attr = it->getValue();
136 attrs.erase(it);
137 dictionarySorted.setPointer(nullptr);
138 return attr;
139}
140
142 auto it = findAttr(*this, name);
143 return it.second ? eraseImpl(it.first) : Attribute();
144}
145
147 auto it = findAttr(*this, name);
148 return it.second ? eraseImpl(it.first) : Attribute();
149}
150
153 assign(rhs.begin(), rhs.end());
154 return *this;
155}
156
157NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
158
159//===----------------------------------------------------------------------===//
160// OperationState
161//===----------------------------------------------------------------------===//
162
165
168
173 MutableArrayRef<std::unique_ptr<Region>> regions)
175 operands(operands.begin(), operands.end()),
176 types(types.begin(), types.end()),
177 attributes(attributes.begin(), attributes.end()),
178 successors(successors.begin(), successors.end()) {
179 for (std::unique_ptr<Region> &r : regions)
180 this->regions.push_back(std::move(r));
181}
189
191 if (properties)
192 propertiesDeleter(properties);
193}
194
197 if (LLVM_UNLIKELY(propertiesAttr)) {
198 assert(!properties);
200 }
201 if (properties)
202 propertiesSetter(op->getPropertiesStorage(), properties);
203 return success();
204}
205
207 operands.append(newOperands.begin(), newOperands.end());
208}
209
211 successors.append(newSuccessors.begin(), newSuccessors.end());
212}
213
215 regions.emplace_back(new Region);
216 return regions.back().get();
217}
218
219void OperationState::addRegion(std::unique_ptr<Region> &&region) {
220 regions.push_back(std::move(region));
221}
222
224 MutableArrayRef<std::unique_ptr<Region>> regions) {
225 for (std::unique_ptr<Region> &region : regions)
226 addRegion(std::move(region));
227}
228
229//===----------------------------------------------------------------------===//
230// OperandStorage
231//===----------------------------------------------------------------------===//
232
234 OpOperand *trailingOperands,
235 ValueRange values)
236 : isStorageDynamic(false), operandStorage(trailingOperands) {
237 numOperands = capacity = values.size();
238 for (unsigned i = 0; i < numOperands; ++i)
239 new (&operandStorage[i]) OpOperand(owner, values[i]);
240}
241
243 for (auto &operand : getOperands())
244 operand.~OpOperand();
245
246 // If the storage is dynamic, deallocate it.
247 if (isStorageDynamic)
248 free(operandStorage);
249}
250
251/// Replace the operands contained in the storage with the ones provided in
252/// 'values'.
254 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
255 for (unsigned i = 0, e = values.size(); i != e; ++i)
256 storageOperands[i].set(values[i]);
257}
258
259/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
260/// with the ones provided in 'operands'. 'operands' may be smaller or larger
261/// than the range pointed to by 'start'+'length'.
263 unsigned length, ValueRange operands) {
264 // If the new size is the same, we can update inplace.
265 unsigned newSize = operands.size();
266 if (newSize == length) {
267 MutableArrayRef<OpOperand> storageOperands = getOperands();
268 for (unsigned i = 0, e = length; i != e; ++i)
269 storageOperands[start + i].set(operands[i]);
270 return;
271 }
272 // If the new size is greater, remove the extra operands and set the rest
273 // inplace.
274 if (newSize < length) {
275 eraseOperands(start + operands.size(), length - newSize);
276 setOperands(owner, start, newSize, operands);
277 return;
278 }
279 // Otherwise, the new size is greater so we need to grow the storage.
280 auto storageOperands = resize(owner, size() + (newSize - length));
281
282 // Shift operands to the right to make space for the new operands.
283 unsigned rotateSize = storageOperands.size() - (start + length);
284 auto rbegin = storageOperands.rbegin();
285 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
286
287 // Update the operands inplace.
288 for (unsigned i = 0, e = operands.size(); i != e; ++i)
289 storageOperands[start + i].set(operands[i]);
290}
291
292/// Erase an operand held by the storage.
293void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
295 assert((start + length) <= operands.size());
296 numOperands -= length;
297
298 // Shift all operands down if the operand to remove is not at the end.
299 if (start != numOperands) {
300 auto *indexIt = std::next(operands.begin(), start);
301 std::rotate(indexIt, std::next(indexIt, length), operands.end());
302 }
303 for (unsigned i = 0; i != length; ++i)
304 operands[numOperands + i].~OpOperand();
305}
306
307void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
309 assert(eraseIndices.size() == operands.size());
310
311 // Check that at least one operand is erased.
312 int firstErasedIndice = eraseIndices.find_first();
313 if (firstErasedIndice == -1)
314 return;
315
316 // Shift all of the removed operands to the end, and destroy them.
317 numOperands = firstErasedIndice;
318 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
319 if (!eraseIndices.test(i))
320 operands[numOperands++] = std::move(operands[i]);
321 for (OpOperand &operand : operands.drop_front(numOperands))
322 operand.~OpOperand();
323}
324
325/// Resize the storage to the given size. Returns the array containing the new
326/// operands.
327MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
328 unsigned newSize) {
329 // If the number of operands is less than or equal to the current amount, we
330 // can just update in place.
331 MutableArrayRef<OpOperand> origOperands = getOperands();
332 if (newSize <= numOperands) {
333 // If the number of new size is less than the current, remove any extra
334 // operands.
335 for (unsigned i = newSize; i != numOperands; ++i)
336 origOperands[i].~OpOperand();
337 numOperands = newSize;
338 return origOperands.take_front(newSize);
339 }
340
341 // If the new size is within the original inline capacity, grow inplace.
342 if (newSize <= capacity) {
343 OpOperand *opBegin = origOperands.data();
344 for (unsigned e = newSize; numOperands != e; ++numOperands)
345 new (&opBegin[numOperands]) OpOperand(owner);
346 return MutableArrayRef<OpOperand>(opBegin, newSize);
347 }
348
349 // Otherwise, we need to allocate a new storage.
350 unsigned newCapacity =
351 std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
352 OpOperand *newOperandStorage =
353 reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
354
355 // Move the current operands to the new storage.
356 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
357 std::uninitialized_move(origOperands.begin(), origOperands.end(),
358 newOperands.begin());
359
360 // Destroy the original operands.
361 for (auto &operand : origOperands)
362 operand.~OpOperand();
363
364 // Initialize any new operands.
365 for (unsigned e = newSize; numOperands != e; ++numOperands)
366 new (&newOperands[numOperands]) OpOperand(owner);
367
368 // If the current storage is dynamic, free it.
369 if (isStorageDynamic)
370 free(operandStorage);
371
372 // Update the storage representation to use the new dynamic storage.
373 operandStorage = newOperandStorage;
374 capacity = newCapacity;
375 isStorageDynamic = true;
376 return newOperands;
377}
378
379//===----------------------------------------------------------------------===//
380// Operation Value-Iterators
381//===----------------------------------------------------------------------===//
382
383//===----------------------------------------------------------------------===//
384// OperandRange
385//===----------------------------------------------------------------------===//
386
388 assert(!empty() && "range must not be empty");
389 return base->getOperandNumber();
390}
391
393 return OperandRangeRange(*this, segmentSizes);
394}
395
396//===----------------------------------------------------------------------===//
397// OperandRangeRange
398//===----------------------------------------------------------------------===//
399
401 Attribute operandSegments)
402 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
403 llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
404}
405
407 const OwnerT &owner = getBase();
408 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
409 return OperandRange(owner.first, llvm::sum_of(sizeData));
410}
411
412OperandRange OperandRangeRange::dereference(const OwnerT &object,
413 ptrdiff_t index) {
414 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
415 uint32_t startIndex = llvm::sum_of(sizeData.take_front(index));
416 return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
417}
418
419//===----------------------------------------------------------------------===//
420// MutableOperandRange
421//===----------------------------------------------------------------------===//
422
423/// Construct a new mutable range from the given operand, operand start index,
424/// and range length.
426 Operation *owner, unsigned start, unsigned length,
427 ArrayRef<OperandSegment> operandSegments)
428 : owner(owner), start(start), length(length),
429 operandSegments(operandSegments) {
430 assert((start + length) <= owner->getNumOperands() && "invalid range");
431}
433 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
434
435/// Construct a new mutable range for the given OpOperand.
437 : MutableOperandRange(opOperand.getOwner(),
438 /*start=*/opOperand.getOperandNumber(),
439 /*length=*/1) {}
440
441/// Slice this range into a sub range, with the additional operand segment.
443MutableOperandRange::slice(unsigned subStart, unsigned subLen,
444 std::optional<OperandSegment> segment) const {
445 assert((subStart + subLen) <= length && "invalid sub-range");
446 MutableOperandRange subSlice(owner, start + subStart, subLen,
447 operandSegments);
448 if (segment)
449 subSlice.operandSegments.push_back(*segment);
450 return subSlice;
451}
452
453/// Append the given values to the range.
455 if (values.empty())
456 return;
457 owner->insertOperands(start + length, values);
458 updateLength(length + values.size());
459}
460
461/// Assign this range to the given values.
463 owner->setOperands(start, length, values);
464 if (length != values.size())
465 updateLength(/*newLength=*/values.size());
466}
467
468/// Assign the range to the given value.
470 if (length == 1) {
471 owner->setOperand(start, value);
472 } else {
473 owner->setOperands(start, length, value);
474 updateLength(/*newLength=*/1);
475 }
476}
477
478/// Erase the operands within the given sub-range.
479void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
480 assert((subStart + subLen) <= length && "invalid sub-range");
481 if (length == 0)
482 return;
483 owner->eraseOperands(start + subStart, subLen);
484 updateLength(length - subLen);
485}
486
487/// Clear this range and erase all of the operands.
489 if (length != 0) {
490 owner->eraseOperands(start, length);
491 updateLength(/*newLength=*/0);
492 }
493}
494
495/// Explicit conversion to an OperandRange.
497 return owner->getOperands().slice(start, length);
498}
499
500/// Allow implicit conversion to an OperandRange.
501MutableOperandRange::operator OperandRange() const {
502 return getAsOperandRange();
503}
504
505MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
506 return owner->getOpOperands().slice(start, length);
507}
508
511 return MutableOperandRangeRange(*this, segmentSizes);
512}
513
514/// Update the length of this range to the one provided.
515void MutableOperandRange::updateLength(unsigned newLength) {
516 int32_t diff = int32_t(newLength) - int32_t(length);
517 length = newLength;
518
519 // Update any of the provided segment attributes.
520 for (OperandSegment &segment : operandSegments) {
521 auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
522 SmallVector<int32_t, 8> segments(attr.asArrayRef());
523 segments[segment.first] += diff;
524 segment.second.setValue(
525 DenseI32ArrayAttr::get(attr.getContext(), segments));
526 owner->setAttr(segment.second.getName(), segment.second.getValue());
527 }
528}
529
531 assert(index < length && "index is out of bounds");
532 return owner->getOpOperand(start + index);
533}
534
536 return owner->getOpOperands().slice(start, length).begin();
537}
538
540 return owner->getOpOperands().slice(start, length).end();
541}
542
543//===----------------------------------------------------------------------===//
544// MutableOperandRangeRange
545//===----------------------------------------------------------------------===//
546
548 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
550 OwnerT(operands, operandSegmentAttr), 0,
551 llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
552}
553
557
558MutableOperandRangeRange::operator OperandRangeRange() const {
559 return OperandRangeRange(getBase().first, getBase().second.getValue());
560}
561
562MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
563 ptrdiff_t index) {
564 ArrayRef<int32_t> sizeData =
565 llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
566 uint32_t startIndex = llvm::sum_of(sizeData.take_front(index));
567 return object.first.slice(
568 startIndex, *(sizeData.begin() + index),
570}
571
572//===----------------------------------------------------------------------===//
573// ResultRange
574//===----------------------------------------------------------------------===//
575
577 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
578 1) {}
579
587 return use_iterator(*this, /*end=*/true);
588}
598
600 : it(end ? results.end() : results.begin()), endIt(results.end()) {
601 // Only initialize current use if there are results/can be uses.
602 if (it != endIt)
603 skipOverResultsWithNoUsers();
604}
605
607 // We increment over uses, if we reach the last use then move to next
608 // result.
609 if (use != (*it).use_end())
610 ++use;
611 if (use == (*it).use_end()) {
612 ++it;
613 skipOverResultsWithNoUsers();
614 }
615 return *this;
616}
617
618void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
619 while (it != endIt && (*it).use_empty())
620 ++it;
621
622 // If we are at the last result, then set use to first use of
623 // first result (sentinel value used for end).
624 if (it == endIt)
625 use = {};
626 else
627 use = (*it).use_begin();
628}
629
633
635 Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
636 replaceUsesWithIf(op->getResults(), shouldReplace);
637}
638
639//===----------------------------------------------------------------------===//
640// ValueRange
641//===----------------------------------------------------------------------===//
642
644 : ValueRange(values.data(), values.size()) {}
646 : ValueRange(values.begin().getBase(), values.size()) {}
648 : ValueRange(values.getBase(), values.size()) {}
649
650/// See `llvm::detail::indexed_accessor_range_base` for details.
651ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
652 ptrdiff_t index) {
653 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
654 return {value + index};
655 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
656 return {operand + index};
657 // All elements are identical; the owner pointer never advances.
658 if (llvm::isa<const Repeated<Value> *>(owner))
659 return owner;
660 return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
661}
662/// See `llvm::detail::indexed_accessor_range_base` for details.
663Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
664 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
665 return value[index];
666 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
667 return operand[index].get();
668 if (auto *repeated =
669 llvm::dyn_cast_if_present<const Repeated<Value> *>(owner))
670 return repeated->value();
671 return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
672}
673
674//===----------------------------------------------------------------------===//
675// Operation Equivalency
676//===----------------------------------------------------------------------===//
677
679 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
680 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
681 // Hash operations based upon their:
682 // - Operation Name
683 // - Attributes
684 // - Result Types
685 DictionaryAttr dictAttrs;
686 if (!(flags & Flags::IgnoreDiscardableAttrs))
687 dictAttrs = op->getRawDictionaryAttrs();
688 llvm::hash_code hash =
689 llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes());
690 if (!(flags & Flags::IgnoreProperties))
691 hash = llvm::hash_combine(hash, op->hashProperties());
692
693 // - Location if required
694 if (!(flags & Flags::IgnoreLocations))
695 hash = llvm::hash_combine(hash, op->getLoc());
696
697 // - Operands
698 if (!(flags & Flags::IgnoreCommutativity) &&
700 op->getNumOperands() > 0) {
701 size_t operandHash = hashOperands(op->getOperand(0));
702 for (auto operand : op->getOperands().drop_front())
703 operandHash += hashOperands(operand);
704 hash = llvm::hash_combine(hash, operandHash);
705 } else {
706 for (Value operand : op->getOperands())
707 hash = llvm::hash_combine(hash, hashOperands(operand));
708 }
709
710 // - Results
711 for (Value result : op->getResults())
712 hash = llvm::hash_combine(hash, hashResults(result));
713 return hash;
714}
715
717 Region *lhs, Region *rhs,
718 function_ref<LogicalResult(Value, Value)> checkEquivalent,
719 function_ref<void(Value, Value)> markEquivalent,
721 function_ref<LogicalResult(ValueRange, ValueRange)>
722 checkCommutativeEquivalent) {
724 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
725 // Check block arguments.
726 if (lBlock.getNumArguments() != rBlock.getNumArguments())
727 return false;
728
729 // Map the two blocks.
730 auto insertion = blocksMap.insert({&lBlock, &rBlock});
731 if (insertion.first->getSecond() != &rBlock)
732 return false;
733
734 for (auto argPair :
735 llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
736 Value curArg = std::get<0>(argPair);
737 Value otherArg = std::get<1>(argPair);
738 if (curArg.getType() != otherArg.getType())
739 return false;
741 curArg.getLoc() != otherArg.getLoc())
742 return false;
743 // Corresponding bbArgs are equivalent.
744 if (markEquivalent)
745 markEquivalent(curArg, otherArg);
746 }
747
748 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
749 // Check for op equality (recursively).
750 if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
751 markEquivalent, flags,
752 checkCommutativeEquivalent))
753 return false;
754 // Check successor mapping.
755 for (auto successorsPair :
756 llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
757 Block *curSuccessor = std::get<0>(successorsPair);
758 Block *otherSuccessor = std::get<1>(successorsPair);
759 auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
760 if (insertion.first->getSecond() != otherSuccessor)
761 return false;
762 }
763 return true;
764 };
765 return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
766 };
767 return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
768}
769
770// Value equivalence cache to be used with `isRegionEquivalentTo` and
771// `isEquivalentTo`.
774 LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
775 return success(lhsValue == rhsValue ||
776 equivalentValues.lookup(lhsValue) == rhsValue);
777 }
779 ValueRange rhsRange) {
780 // Handle simple case where sizes mismatch.
781 if (lhsRange.size() != rhsRange.size())
782 return failure();
783
784 // Handle where operands in order are equivalent.
785 auto lhsIt = lhsRange.begin();
786 auto rhsIt = rhsRange.begin();
787 for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
788 if (failed(checkEquivalent(*lhsIt, *rhsIt)))
789 break;
790 }
791 if (lhsIt == lhsRange.end())
792 return success();
793
794 // Handle another simple case where operands are just a permutation.
795 // Note: This is not sufficient, this handles simple cases relatively
796 // cheaply.
797 auto sortValues = [](ValueRange values) {
798 SmallVector<Value> sortedValues = llvm::to_vector(values);
799 llvm::sort(sortedValues, [](Value a, Value b) {
800 return a.getAsOpaquePointer() < b.getAsOpaquePointer();
801 });
802 return sortedValues;
803 };
804 auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
805 auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
806 return success(lhsSorted == rhsSorted);
807 }
808 void markEquivalent(Value lhsResult, Value rhsResult) {
809 auto insertion = equivalentValues.insert({lhsResult, rhsResult});
810 // Make sure that the value was not already marked equivalent to some other
811 // value.
812 (void)insertion;
813 assert(insertion.first->second == rhsResult &&
814 "inconsistent OperationEquivalence state");
815 }
816};
817
818/*static*/ bool
823 lhs, rhs,
824 [&](Value lhsValue, Value rhsValue) -> LogicalResult {
825 return cache.checkEquivalent(lhsValue, rhsValue);
826 },
827 [&](Value lhsResult, Value rhsResult) {
828 cache.markEquivalent(lhsResult, rhsResult);
829 },
830 flags,
831 [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
832 return cache.checkCommutativeEquivalent(lhs, rhs);
833 });
834}
835
838 function_ref<LogicalResult(Value, Value)> checkEquivalent,
839 function_ref<void(Value, Value)> markEquivalent, Flags flags,
840 function_ref<LogicalResult(ValueRange, ValueRange)>
841 checkCommutativeEquivalent) {
842 if (lhs == rhs)
843 return true;
844
845 // 1. Compare the operation properties.
846 if (!(flags & IgnoreDiscardableAttrs) &&
847 lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs())
848 return false;
849
850 if (lhs->getName() != rhs->getName() ||
851 lhs->getNumRegions() != rhs->getNumRegions() ||
852 lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
853 lhs->getNumOperands() != rhs->getNumOperands() ||
854 lhs->getNumResults() != rhs->getNumResults())
855 return false;
856 if (!(flags & IgnoreProperties) &&
857 !(lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
858 rhs->getPropertiesStorage())))
859 return false;
860 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
861 return false;
862
863 // 2. Compare operands.
864 if (!(flags & IgnoreCommutativity) && checkCommutativeEquivalent &&
865 lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
866 auto lhsRange = lhs->getOperands();
867 auto rhsRange = rhs->getOperands();
868 if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
869 return false;
870 } else {
871 // Check pair wise for equivalence.
872 for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
873 Value curArg = std::get<0>(operandPair);
874 Value otherArg = std::get<1>(operandPair);
875 if (curArg == otherArg)
876 continue;
877 if (curArg.getType() != otherArg.getType())
878 return false;
879 if (failed(checkEquivalent(curArg, otherArg)))
880 return false;
881 }
882 }
883
884 // 3. Compare result types and mark results as equivalent.
885 for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
886 Value curArg = std::get<0>(resultPair);
887 Value otherArg = std::get<1>(resultPair);
888 if (curArg.getType() != otherArg.getType())
889 return false;
890 if (markEquivalent)
891 markEquivalent(curArg, otherArg);
892 }
893
894 // 4. Compare regions.
895 for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
896 if (!isRegionEquivalentTo(&std::get<0>(regionPair),
897 &std::get<1>(regionPair), checkEquivalent,
898 markEquivalent, flags))
899 return false;
900
901 return true;
902}
903
905 Operation *rhs,
906 Flags flags) {
909 lhs, rhs,
910 [&](Value lhsValue, Value rhsValue) -> LogicalResult {
911 return cache.checkEquivalent(lhsValue, rhsValue);
912 },
913 [&](Value lhsResult, Value rhsResult) {
914 cache.markEquivalent(lhsResult, rhsResult);
915 },
916 flags,
917 [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
918 return cache.checkCommutativeEquivalent(lhs, rhs);
919 });
920}
921
922//===----------------------------------------------------------------------===//
923// OperationFingerPrint
924//===----------------------------------------------------------------------===//
925
926template <typename T>
927static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
928 hasher.update(
929 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
930}
931
933 bool includeNested) {
934 llvm::SHA1 hasher;
935
936 // Helper function that hashes an operation based on its mutable bits:
937 auto addOperationToHash = [&](Operation *op) {
938 // - Operation pointer
939 addDataToHash(hasher, op);
940 // - Parent operation pointer (to take into account the nesting structure)
941 if (op != topOp)
942 addDataToHash(hasher, op->getParentOp());
943 // - Attributes
945 // - Properties
946 addDataToHash(hasher, op->hashProperties());
947 // - Blocks in Regions
948 for (Region &region : op->getRegions()) {
949 for (Block &block : region) {
950 addDataToHash(hasher, &block);
951 for (BlockArgument arg : block.getArguments())
952 addDataToHash(hasher, arg);
953 }
954 }
955 // - Location
956 addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
957 // - Operands
958 for (Value operand : op->getOperands())
959 addDataToHash(hasher, operand);
960 // - Successors
961 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
962 addDataToHash(hasher, op->getSuccessor(i));
963 // - Result types
964 for (Type t : op->getResultTypes())
965 addDataToHash(hasher, t);
966 };
967
968 if (includeNested)
969 topOp->walk(addOperationToHash);
970 else
971 addOperationToHash(topOp);
972
973 hash = hasher.result();
974}
975
return success()
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static size_t hash(const T &value)
Local helper to compute std::hash for a value.
Definition IRCore.cpp:55
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
false
Parses a map_entries map type from a string format back into its numeric value.
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class represents an argument of a Block.
Definition Value.h:306
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
unsigned getNumArguments()
Definition Block.h:138
BlockArgListType getArguments()
Definition Block.h:97
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Location.h:103
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a contiguous range of mutable operand ranges, e.g.
Definition ValueRange.h:211
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
Operation * getOwner() const
Returns the owning operation.
Definition ValueRange.h:172
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments={})
Construct a new mutable range from the given operand, operand start index, and range length.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition ValueRange.h:124
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists; otherwise returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
This class adds property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:85
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp, bool includeNested=true)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
PropertyRef getPropertiesStorage()
Return a generic (but typed) reference to the property type storage.
Definition Operation.h:927
Value getOperand(unsigned idx)
Definition Operation.h:376
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
unsigned getNumSuccessors()
Definition Operation.h:732
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition Operation.h:535
unsigned getNumOperands()
Definition Operation.h:372
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition Operation.h:404
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
Block * getSuccessor(unsigned index)
Definition Operation.h:734
SuccessorRange getSuccessors()
Definition Operation.h:729
result_range getResults()
Definition Operation.h:441
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class implements a use iterator for a range of operation results.
Definition ValueRange.h:351
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition ValueRange.h:248
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition ValueRange.h:304
iterator_range< user_iterator > user_range
Definition ValueRange.h:324
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition ValueRange.h:323
ResultRange(OpResult result)
use_iterator use_end() const
user_iterator user_end()
user_iterator user_begin()
iterator_range< use_iterator > use_range
Definition ValueRange.h:269
UseIterator use_iterator
Definition ValueRange.h:268
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition ValueRange.h:287
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl *, const Repeated< Value > * > OwnerT
The type representing the owner of a ValueRange.
Definition ValueRange.h:393
ValueRange(Arg &&arg LLVM_LIFETIME_BOUND)
Definition ValueRange.h:402
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Value.h:233
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
MutableArrayRef< OpOperand > getOperands()
Get the operation operands held by the storage.
unsigned size()
Return the number of operands held in the storage.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
AttrTypeReplacer.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, ValueRange rhsRange)
DenseMap< Value, Value > equivalentValues
Structure used by default as a "marker" when no "Properties" are set on an Operation.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
Adds a successor to the operation state. The successor must not be null.
void addRegions(MutableArrayRef< std::unique_ptr< Region > > regions)
Take ownership of a set of regions that should be attached to the Operation.
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const