MLIR 22.0.0git
OperationSupport.cpp
Go to the documentation of this file.
1//===- OperationSupport.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains out-of-line implementations of the support types that
10// Operation and related classes build on top of.
11//
12//===----------------------------------------------------------------------===//
13
18#include "llvm/Support/SHA1.h"
19#include <numeric>
20#include <optional>
21
22using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// NamedAttrList
26//===----------------------------------------------------------------------===//
27
29 assign(attributes.begin(), attributes.end());
30}
31
32NamedAttrList::NamedAttrList(DictionaryAttr attributes)
33 : NamedAttrList(attributes ? attributes.getValue()
35 dictionarySorted.setPointerAndInt(attributes, true);
36}
37
39 assign(inStart, inEnd);
40}
41
43
44std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
45 std::optional<NamedAttribute> duplicate =
46 DictionaryAttr::findDuplicate(attrs, isSorted());
47 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
48 // state.
49 if (!isSorted())
50 dictionarySorted.setPointerAndInt(nullptr, true);
51 return duplicate;
52}
53
54DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
55 if (!isSorted()) {
56 DictionaryAttr::sortInPlace(attrs);
57 dictionarySorted.setPointerAndInt(nullptr, true);
58 }
59 if (!dictionarySorted.getPointer())
60 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
61 return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
62}
63
64/// Replaces the attributes with new list of attributes.
66 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
67 dictionarySorted.setPointerAndInt(nullptr, true);
68}
69
71 if (isSorted())
72 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
73 dictionarySorted.setPointer(nullptr);
74 attrs.push_back(newAttribute);
75}
76
77/// Return the specified attribute if present, null otherwise.
78Attribute NamedAttrList::get(StringRef name) const {
79 auto it = findAttr(*this, name);
80 return it.second ? it.first->getValue() : Attribute();
81}
82Attribute NamedAttrList::get(StringAttr name) const {
83 auto it = findAttr(*this, name);
84 return it.second ? it.first->getValue() : Attribute();
85}
86
87/// Return the specified named attribute if present, std::nullopt otherwise.
88std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
89 auto it = findAttr(*this, name);
90 return it.second ? *it.first : std::optional<NamedAttribute>();
91}
92std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
93 auto it = findAttr(*this, name);
94 return it.second ? *it.first : std::optional<NamedAttribute>();
95}
96
97/// If the an attribute exists with the specified name, change it to the new
98/// value. Otherwise, add a new attribute with the specified name/value.
99Attribute NamedAttrList::set(StringAttr name, Attribute value) {
100 assert(value && "attributes may never be null");
101
102 // Look for an existing attribute with the given name, and set its value
103 // in-place. Return the previous value of the attribute, if there was one.
104 auto it = findAttr(*this, name);
105 if (it.second) {
106 // Update the existing attribute by swapping out the old value for the new
107 // value. Return the old value.
108 Attribute oldValue = it.first->getValue();
109 if (it.first->getValue() != value) {
110 it.first->setValue(value);
111
112 // If the attributes have changed, the dictionary is invalidated.
113 dictionarySorted.setPointer(nullptr);
114 }
115 return oldValue;
116 }
117 // Perform a string lookup to insert the new attribute into its sorted
118 // position.
119 if (isSorted())
120 it = findAttr(*this, name.strref());
121 attrs.insert(it.first, {name, value});
122 // Invalidate the dictionary. Return null as there was no previous value.
123 dictionarySorted.setPointer(nullptr);
124 return Attribute();
125}
126
127Attribute NamedAttrList::set(StringRef name, Attribute value) {
128 assert(value && "attributes may never be null");
129 return set(mlir::StringAttr::get(value.getContext(), name), value);
130}
131
133NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
134 // Erasing does not affect the sorted property.
135 Attribute attr = it->getValue();
136 attrs.erase(it);
137 dictionarySorted.setPointer(nullptr);
138 return attr;
139}
140
142 auto it = findAttr(*this, name);
143 return it.second ? eraseImpl(it.first) : Attribute();
144}
145
147 auto it = findAttr(*this, name);
148 return it.second ? eraseImpl(it.first) : Attribute();
149}
150
153 assign(rhs.begin(), rhs.end());
154 return *this;
155}
156
157NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
158
159//===----------------------------------------------------------------------===//
160// OperationState
161//===----------------------------------------------------------------------===//
162
165
168
173 MutableArrayRef<std::unique_ptr<Region>> regions)
175 operands(operands.begin(), operands.end()),
176 types(types.begin(), types.end()),
177 attributes(attributes.begin(), attributes.end()),
178 successors(successors.begin(), successors.end()) {
179 for (std::unique_ptr<Region> &r : regions)
180 this->regions.push_back(std::move(r));
181}
189
191 if (properties)
192 propertiesDeleter(properties);
193}
194
197 if (LLVM_UNLIKELY(propertiesAttr)) {
198 assert(!properties);
200 }
201 if (properties)
202 propertiesSetter(op->getPropertiesStorage(), properties);
203 return success();
204}
205
207 operands.append(newOperands.begin(), newOperands.end());
208}
209
211 successors.append(newSuccessors.begin(), newSuccessors.end());
212}
213
215 regions.emplace_back(new Region);
216 return regions.back().get();
217}
218
219void OperationState::addRegion(std::unique_ptr<Region> &&region) {
220 regions.push_back(std::move(region));
221}
222
224 MutableArrayRef<std::unique_ptr<Region>> regions) {
225 for (std::unique_ptr<Region> &region : regions)
226 addRegion(std::move(region));
227}
228
229//===----------------------------------------------------------------------===//
230// OperandStorage
231//===----------------------------------------------------------------------===//
232
234 OpOperand *trailingOperands,
235 ValueRange values)
236 : isStorageDynamic(false), operandStorage(trailingOperands) {
237 numOperands = capacity = values.size();
238 for (unsigned i = 0; i < numOperands; ++i)
239 new (&operandStorage[i]) OpOperand(owner, values[i]);
240}
241
243 for (auto &operand : getOperands())
244 operand.~OpOperand();
245
246 // If the storage is dynamic, deallocate it.
247 if (isStorageDynamic)
248 free(operandStorage);
249}
250
251/// Replace the operands contained in the storage with the ones provided in
252/// 'values'.
254 MutableArrayRef<OpOperand> storageOperands = resize(owner, values.size());
255 for (unsigned i = 0, e = values.size(); i != e; ++i)
256 storageOperands[i].set(values[i]);
257}
258
259/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
260/// with the ones provided in 'operands'. 'operands' may be smaller or larger
261/// than the range pointed to by 'start'+'length'.
263 unsigned length, ValueRange operands) {
264 // If the new size is the same, we can update inplace.
265 unsigned newSize = operands.size();
266 if (newSize == length) {
267 MutableArrayRef<OpOperand> storageOperands = getOperands();
268 for (unsigned i = 0, e = length; i != e; ++i)
269 storageOperands[start + i].set(operands[i]);
270 return;
271 }
272 // If the new size is greater, remove the extra operands and set the rest
273 // inplace.
274 if (newSize < length) {
275 eraseOperands(start + operands.size(), length - newSize);
276 setOperands(owner, start, newSize, operands);
277 return;
278 }
279 // Otherwise, the new size is greater so we need to grow the storage.
280 auto storageOperands = resize(owner, size() + (newSize - length));
281
282 // Shift operands to the right to make space for the new operands.
283 unsigned rotateSize = storageOperands.size() - (start + length);
284 auto rbegin = storageOperands.rbegin();
285 std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
286
287 // Update the operands inplace.
288 for (unsigned i = 0, e = operands.size(); i != e; ++i)
289 storageOperands[start + i].set(operands[i]);
290}
291
292/// Erase an operand held by the storage.
293void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
295 assert((start + length) <= operands.size());
296 numOperands -= length;
297
298 // Shift all operands down if the operand to remove is not at the end.
299 if (start != numOperands) {
300 auto *indexIt = std::next(operands.begin(), start);
301 std::rotate(indexIt, std::next(indexIt, length), operands.end());
302 }
303 for (unsigned i = 0; i != length; ++i)
304 operands[numOperands + i].~OpOperand();
305}
306
307void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
309 assert(eraseIndices.size() == operands.size());
310
311 // Check that at least one operand is erased.
312 int firstErasedIndice = eraseIndices.find_first();
313 if (firstErasedIndice == -1)
314 return;
315
316 // Shift all of the removed operands to the end, and destroy them.
317 numOperands = firstErasedIndice;
318 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
319 if (!eraseIndices.test(i))
320 operands[numOperands++] = std::move(operands[i]);
321 for (OpOperand &operand : operands.drop_front(numOperands))
322 operand.~OpOperand();
323}
324
325/// Resize the storage to the given size. Returns the array containing the new
326/// operands.
327MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
328 unsigned newSize) {
329 // If the number of operands is less than or equal to the current amount, we
330 // can just update in place.
331 MutableArrayRef<OpOperand> origOperands = getOperands();
332 if (newSize <= numOperands) {
333 // If the number of new size is less than the current, remove any extra
334 // operands.
335 for (unsigned i = newSize; i != numOperands; ++i)
336 origOperands[i].~OpOperand();
337 numOperands = newSize;
338 return origOperands.take_front(newSize);
339 }
340
341 // If the new size is within the original inline capacity, grow inplace.
342 if (newSize <= capacity) {
343 OpOperand *opBegin = origOperands.data();
344 for (unsigned e = newSize; numOperands != e; ++numOperands)
345 new (&opBegin[numOperands]) OpOperand(owner);
346 return MutableArrayRef<OpOperand>(opBegin, newSize);
347 }
348
349 // Otherwise, we need to allocate a new storage.
350 unsigned newCapacity =
351 std::max(unsigned(llvm::NextPowerOf2(capacity + 2)), newSize);
352 OpOperand *newOperandStorage =
353 reinterpret_cast<OpOperand *>(malloc(sizeof(OpOperand) * newCapacity));
354
355 // Move the current operands to the new storage.
356 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
357 std::uninitialized_move(origOperands.begin(), origOperands.end(),
358 newOperands.begin());
359
360 // Destroy the original operands.
361 for (auto &operand : origOperands)
362 operand.~OpOperand();
363
364 // Initialize any new operands.
365 for (unsigned e = newSize; numOperands != e; ++numOperands)
366 new (&newOperands[numOperands]) OpOperand(owner);
367
368 // If the current storage is dynamic, free it.
369 if (isStorageDynamic)
370 free(operandStorage);
371
372 // Update the storage representation to use the new dynamic storage.
373 operandStorage = newOperandStorage;
374 capacity = newCapacity;
375 isStorageDynamic = true;
376 return newOperands;
377}
378
379//===----------------------------------------------------------------------===//
380// Operation Value-Iterators
381//===----------------------------------------------------------------------===//
382
383//===----------------------------------------------------------------------===//
384// OperandRange
385//===----------------------------------------------------------------------===//
386
388 assert(!empty() && "range must not be empty");
389 return base->getOperandNumber();
390}
391
393 return OperandRangeRange(*this, segmentSizes);
394}
395
396//===----------------------------------------------------------------------===//
397// OperandRangeRange
398//===----------------------------------------------------------------------===//
399
401 Attribute operandSegments)
402 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
403 llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
404}
405
407 const OwnerT &owner = getBase();
408 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
409 return OperandRange(owner.first, llvm::sum_of(sizeData));
410}
411
412OperandRange OperandRangeRange::dereference(const OwnerT &object,
413 ptrdiff_t index) {
414 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
415 uint32_t startIndex = llvm::sum_of(sizeData.take_front(index));
416 return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
417}
418
419//===----------------------------------------------------------------------===//
420// MutableOperandRange
421//===----------------------------------------------------------------------===//
422
423/// Construct a new mutable range from the given operand, operand start index,
424/// and range length.
426 Operation *owner, unsigned start, unsigned length,
427 ArrayRef<OperandSegment> operandSegments)
428 : owner(owner), start(start), length(length),
429 operandSegments(operandSegments) {
430 assert((start + length) <= owner->getNumOperands() && "invalid range");
431}
433 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
434
435/// Construct a new mutable range for the given OpOperand.
437 : MutableOperandRange(opOperand.getOwner(),
438 /*start=*/opOperand.getOperandNumber(),
439 /*length=*/1) {}
440
441/// Slice this range into a sub range, with the additional operand segment.
443MutableOperandRange::slice(unsigned subStart, unsigned subLen,
444 std::optional<OperandSegment> segment) const {
445 assert((subStart + subLen) <= length && "invalid sub-range");
446 MutableOperandRange subSlice(owner, start + subStart, subLen,
447 operandSegments);
448 if (segment)
449 subSlice.operandSegments.push_back(*segment);
450 return subSlice;
451}
452
453/// Append the given values to the range.
455 if (values.empty())
456 return;
457 owner->insertOperands(start + length, values);
458 updateLength(length + values.size());
459}
460
461/// Assign this range to the given values.
463 owner->setOperands(start, length, values);
464 if (length != values.size())
465 updateLength(/*newLength=*/values.size());
466}
467
468/// Assign the range to the given value.
470 if (length == 1) {
471 owner->setOperand(start, value);
472 } else {
473 owner->setOperands(start, length, value);
474 updateLength(/*newLength=*/1);
475 }
476}
477
478/// Erase the operands within the given sub-range.
479void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
480 assert((subStart + subLen) <= length && "invalid sub-range");
481 if (length == 0)
482 return;
483 owner->eraseOperands(start + subStart, subLen);
484 updateLength(length - subLen);
485}
486
487/// Clear this range and erase all of the operands.
489 if (length != 0) {
490 owner->eraseOperands(start, length);
491 updateLength(/*newLength=*/0);
492 }
493}
494
495/// Explicit conversion to an OperandRange.
497 return owner->getOperands().slice(start, length);
498}
499
500/// Allow implicit conversion to an OperandRange.
501MutableOperandRange::operator OperandRange() const {
502 return getAsOperandRange();
503}
504
505MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
506 return owner->getOpOperands().slice(start, length);
507}
508
511 return MutableOperandRangeRange(*this, segmentSizes);
512}
513
514/// Update the length of this range to the one provided.
515void MutableOperandRange::updateLength(unsigned newLength) {
516 int32_t diff = int32_t(newLength) - int32_t(length);
517 length = newLength;
518
519 // Update any of the provided segment attributes.
520 for (OperandSegment &segment : operandSegments) {
521 auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
522 SmallVector<int32_t, 8> segments(attr.asArrayRef());
523 segments[segment.first] += diff;
524 segment.second.setValue(
525 DenseI32ArrayAttr::get(attr.getContext(), segments));
526 owner->setAttr(segment.second.getName(), segment.second.getValue());
527 }
528}
529
531 assert(index < length && "index is out of bounds");
532 return owner->getOpOperand(start + index);
533}
534
536 return owner->getOpOperands().slice(start, length).begin();
537}
538
540 return owner->getOpOperands().slice(start, length).end();
541}
542
543//===----------------------------------------------------------------------===//
544// MutableOperandRangeRange
545//===----------------------------------------------------------------------===//
546
548 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
550 OwnerT(operands, operandSegmentAttr), 0,
551 llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
552}
553
557
558MutableOperandRangeRange::operator OperandRangeRange() const {
559 return OperandRangeRange(getBase().first, getBase().second.getValue());
560}
561
562MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
563 ptrdiff_t index) {
564 ArrayRef<int32_t> sizeData =
565 llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
566 uint32_t startIndex = llvm::sum_of(sizeData.take_front(index));
567 return object.first.slice(
568 startIndex, *(sizeData.begin() + index),
570}
571
572//===----------------------------------------------------------------------===//
573// ResultRange
574//===----------------------------------------------------------------------===//
575
577 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
578 1) {}
579
587 return use_iterator(*this, /*end=*/true);
588}
598
600 : it(end ? results.end() : results.begin()), endIt(results.end()) {
601 // Only initialize current use if there are results/can be uses.
602 if (it != endIt)
603 skipOverResultsWithNoUsers();
604}
605
607 // We increment over uses, if we reach the last use then move to next
608 // result.
609 if (use != (*it).use_end())
610 ++use;
611 if (use == (*it).use_end()) {
612 ++it;
613 skipOverResultsWithNoUsers();
614 }
615 return *this;
616}
617
618void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
619 while (it != endIt && (*it).use_empty())
620 ++it;
621
622 // If we are at the last result, then set use to first use of
623 // first result (sentinel value used for end).
624 if (it == endIt)
625 use = {};
626 else
627 use = (*it).use_begin();
628}
629
633
635 Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
636 replaceUsesWithIf(op->getResults(), shouldReplace);
637}
638
639//===----------------------------------------------------------------------===//
640// ValueRange
641//===----------------------------------------------------------------------===//
642
644 : ValueRange(values.data(), values.size()) {}
646 : ValueRange(values.begin().getBase(), values.size()) {}
648 : ValueRange(values.getBase(), values.size()) {}
649
650/// See `llvm::detail::indexed_accessor_range_base` for details.
651ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
652 ptrdiff_t index) {
653 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
654 return {value + index};
655 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
656 return {operand + index};
657 return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
658}
659/// See `llvm::detail::indexed_accessor_range_base` for details.
660Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
661 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
662 return value[index];
663 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
664 return operand[index].get();
665 return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
666}
667
668//===----------------------------------------------------------------------===//
669// Operation Equivalency
670//===----------------------------------------------------------------------===//
671
673 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
674 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
675 // Hash operations based upon their:
676 // - Operation Name
677 // - Attributes
678 // - Result Types
679 DictionaryAttr dictAttrs;
680 if (!(flags & Flags::IgnoreDiscardableAttrs))
681 dictAttrs = op->getRawDictionaryAttrs();
682 llvm::hash_code hash =
683 llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes());
684 if (!(flags & Flags::IgnoreProperties))
685 hash = llvm::hash_combine(hash, op->hashProperties());
686
687 // - Location if required
688 if (!(flags & Flags::IgnoreLocations))
689 hash = llvm::hash_combine(hash, op->getLoc());
690
691 // - Operands
693 op->getNumOperands() > 0) {
694 size_t operandHash = hashOperands(op->getOperand(0));
695 for (auto operand : op->getOperands().drop_front())
696 operandHash += hashOperands(operand);
697 hash = llvm::hash_combine(hash, operandHash);
698 } else {
699 for (Value operand : op->getOperands())
700 hash = llvm::hash_combine(hash, hashOperands(operand));
701 }
702
703 // - Results
704 for (Value result : op->getResults())
705 hash = llvm::hash_combine(hash, hashResults(result));
706 return hash;
707}
708
710 Region *lhs, Region *rhs,
711 function_ref<LogicalResult(Value, Value)> checkEquivalent,
712 function_ref<void(Value, Value)> markEquivalent,
714 function_ref<LogicalResult(ValueRange, ValueRange)>
715 checkCommutativeEquivalent) {
717 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
718 // Check block arguments.
719 if (lBlock.getNumArguments() != rBlock.getNumArguments())
720 return false;
721
722 // Map the two blocks.
723 auto insertion = blocksMap.insert({&lBlock, &rBlock});
724 if (insertion.first->getSecond() != &rBlock)
725 return false;
726
727 for (auto argPair :
728 llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
729 Value curArg = std::get<0>(argPair);
730 Value otherArg = std::get<1>(argPair);
731 if (curArg.getType() != otherArg.getType())
732 return false;
734 curArg.getLoc() != otherArg.getLoc())
735 return false;
736 // Corresponding bbArgs are equivalent.
737 if (markEquivalent)
738 markEquivalent(curArg, otherArg);
739 }
740
741 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
742 // Check for op equality (recursively).
743 if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
744 markEquivalent, flags,
745 checkCommutativeEquivalent))
746 return false;
747 // Check successor mapping.
748 for (auto successorsPair :
749 llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
750 Block *curSuccessor = std::get<0>(successorsPair);
751 Block *otherSuccessor = std::get<1>(successorsPair);
752 auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
753 if (insertion.first->getSecond() != otherSuccessor)
754 return false;
755 }
756 return true;
757 };
758 return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
759 };
760 return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
761}
762
763// Value equivalence cache to be used with `isRegionEquivalentTo` and
764// `isEquivalentTo`.
767 LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
768 return success(lhsValue == rhsValue ||
769 equivalentValues.lookup(lhsValue) == rhsValue);
770 }
772 ValueRange rhsRange) {
773 // Handle simple case where sizes mismatch.
774 if (lhsRange.size() != rhsRange.size())
775 return failure();
776
777 // Handle where operands in order are equivalent.
778 auto lhsIt = lhsRange.begin();
779 auto rhsIt = rhsRange.begin();
780 for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
781 if (failed(checkEquivalent(*lhsIt, *rhsIt)))
782 break;
783 }
784 if (lhsIt == lhsRange.end())
785 return success();
786
787 // Handle another simple case where operands are just a permutation.
788 // Note: This is not sufficient, this handles simple cases relatively
789 // cheaply.
790 auto sortValues = [](ValueRange values) {
791 SmallVector<Value> sortedValues = llvm::to_vector(values);
792 llvm::sort(sortedValues, [](Value a, Value b) {
793 return a.getAsOpaquePointer() < b.getAsOpaquePointer();
794 });
795 return sortedValues;
796 };
797 auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
798 auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
799 return success(lhsSorted == rhsSorted);
800 }
801 void markEquivalent(Value lhsResult, Value rhsResult) {
802 auto insertion = equivalentValues.insert({lhsResult, rhsResult});
803 // Make sure that the value was not already marked equivalent to some other
804 // value.
805 (void)insertion;
806 assert(insertion.first->second == rhsResult &&
807 "inconsistent OperationEquivalence state");
808 }
809};
810
811/*static*/ bool
816 lhs, rhs,
817 [&](Value lhsValue, Value rhsValue) -> LogicalResult {
818 return cache.checkEquivalent(lhsValue, rhsValue);
819 },
820 [&](Value lhsResult, Value rhsResult) {
821 cache.markEquivalent(lhsResult, rhsResult);
822 },
823 flags,
824 [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
825 return cache.checkCommutativeEquivalent(lhs, rhs);
826 });
827}
828
831 function_ref<LogicalResult(Value, Value)> checkEquivalent,
832 function_ref<void(Value, Value)> markEquivalent, Flags flags,
833 function_ref<LogicalResult(ValueRange, ValueRange)>
834 checkCommutativeEquivalent) {
835 if (lhs == rhs)
836 return true;
837
838 // 1. Compare the operation properties.
839 if (!(flags & IgnoreDiscardableAttrs) &&
840 lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs())
841 return false;
842
843 if (lhs->getName() != rhs->getName() ||
844 lhs->getNumRegions() != rhs->getNumRegions() ||
845 lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
846 lhs->getNumOperands() != rhs->getNumOperands() ||
847 lhs->getNumResults() != rhs->getNumResults())
848 return false;
849 if (!(flags & IgnoreProperties) &&
850 !(lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
851 rhs->getPropertiesStorage())))
852 return false;
853 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
854 return false;
855
856 // 2. Compare operands.
857 if (checkCommutativeEquivalent &&
858 lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
859 auto lhsRange = lhs->getOperands();
860 auto rhsRange = rhs->getOperands();
861 if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
862 return false;
863 } else {
864 // Check pair wise for equivalence.
865 for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
866 Value curArg = std::get<0>(operandPair);
867 Value otherArg = std::get<1>(operandPair);
868 if (curArg == otherArg)
869 continue;
870 if (curArg.getType() != otherArg.getType())
871 return false;
872 if (failed(checkEquivalent(curArg, otherArg)))
873 return false;
874 }
875 }
876
877 // 3. Compare result types and mark results as equivalent.
878 for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
879 Value curArg = std::get<0>(resultPair);
880 Value otherArg = std::get<1>(resultPair);
881 if (curArg.getType() != otherArg.getType())
882 return false;
883 if (markEquivalent)
884 markEquivalent(curArg, otherArg);
885 }
886
887 // 4. Compare regions.
888 for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
889 if (!isRegionEquivalentTo(&std::get<0>(regionPair),
890 &std::get<1>(regionPair), checkEquivalent,
891 markEquivalent, flags))
892 return false;
893
894 return true;
895}
896
898 Operation *rhs,
899 Flags flags) {
902 lhs, rhs,
903 [&](Value lhsValue, Value rhsValue) -> LogicalResult {
904 return cache.checkEquivalent(lhsValue, rhsValue);
905 },
906 [&](Value lhsResult, Value rhsResult) {
907 cache.markEquivalent(lhsResult, rhsResult);
908 },
909 flags,
910 [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
911 return cache.checkCommutativeEquivalent(lhs, rhs);
912 });
913}
914
915//===----------------------------------------------------------------------===//
916// OperationFingerPrint
917//===----------------------------------------------------------------------===//
918
919template <typename T>
920static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
921 hasher.update(
922 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
923}
924
926 bool includeNested) {
927 llvm::SHA1 hasher;
928
929 // Helper function that hashes an operation based on its mutable bits:
930 auto addOperationToHash = [&](Operation *op) {
931 // - Operation pointer
932 addDataToHash(hasher, op);
933 // - Parent operation pointer (to take into account the nesting structure)
934 if (op != topOp)
935 addDataToHash(hasher, op->getParentOp());
936 // - Attributes
938 // - Properties
939 addDataToHash(hasher, op->hashProperties());
940 // - Blocks in Regions
941 for (Region &region : op->getRegions()) {
942 for (Block &block : region) {
943 addDataToHash(hasher, &block);
944 for (BlockArgument arg : block.getArguments())
945 addDataToHash(hasher, arg);
946 }
947 }
948 // - Location
949 addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
950 // - Operands
951 for (Value operand : op->getOperands())
952 addDataToHash(hasher, operand);
953 // - Successors
954 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
955 addDataToHash(hasher, op->getSuccessor(i));
956 // - Result types
957 for (Type t : op->getResultTypes())
958 addDataToHash(hasher, t);
959 };
960
961 if (includeNested)
962 topOp->walk(addOperationToHash);
963 else
964 addOperationToHash(topOp);
965
966 hash = hasher.result();
967}
return success()
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
false
Parses a map_entries map type from a string format back into its numeric value.
static void addDataToHash(llvm::SHA1 &hasher, const T &data)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class represents an argument of a Block.
Definition Value.h:309
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
unsigned getNumArguments()
Definition Block.h:128
BlockArgListType getArguments()
Definition Block.h:87
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Location.h:103
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a contiguous range of mutable operand ranges, e.g.
Definition ValueRange.h:210
MutableOperandRange join() const
Flatten all of the sub ranges into a single contiguous mutable operand range.
MutableOperandRangeRange(const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
Construct a range given a parent set of operands, and an I32 tensor elements attribute containing the...
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
Operation * getOwner() const
Returns the owning operation.
Definition ValueRange.h:171
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
void assign(ValueRange values)
Assign this range to the given values.
MutableOperandRange slice(unsigned subStart, unsigned subLen, std::optional< OperandSegment > segment=std::nullopt) const
Slice this range into a sub range, with the additional operand segment.
MutableArrayRef< OpOperand >::iterator end() const
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
void append(ValueRange values)
Append the given values to the range.
void clear()
Clear this range and erase all of the operands.
MutableArrayRef< OpOperand >::iterator begin() const
Iterators enumerate OpOperands.
MutableOperandRange(Operation *owner, unsigned start, unsigned length, ArrayRef< OperandSegment > operandSegments={})
Construct a new mutable range from the given operand, operand start index, and range length.
std::pair< unsigned, NamedAttribute > OperandSegment
A pair of a named attribute corresponding to an operand segment attribute, and the index within that ...
Definition ValueRange.h:123
MutableOperandRangeRange split(NamedAttribute segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OpOperand & operator[](unsigned index) const
Returns the OpOperand at the given index.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
SmallVectorImpl< NamedAttribute >::const_iterator const_iterator
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
std::optional< NamedAttribute > findDuplicate() const
Returns an entry with a duplicate name in the list, if it exists, else returns std::nullopt.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttrList & operator=(const SmallVectorImpl< NamedAttribute > &rhs)
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This class adds property that the operation is commutative.
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:84
OperandRangeRange(OperandRange operands, Attribute operandSegments)
Construct a range given a parent set of operands, and an I32 elements attribute containing the sizes ...
OperandRange join() const
Flatten all of the sub ranges into a single contiguous operand range.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const
Split this range into a set of contiguous subranges using the given elements attribute,...
OperationFingerPrint(Operation *topOp, bool includeNested=true)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
unsigned getNumSuccessors()
Definition Operation.h:706
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
Definition Operation.h:509
unsigned getNumOperands()
Definition Operation.h:346
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
LogicalResult setPropertiesFromAttribute(Attribute attr, function_ref< InFlightDiagnostic()> emitError)
Set the properties from the provided attribute.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
Block * getSuccessor(unsigned index)
Definition Operation.h:708
SuccessorRange getSuccessors()
Definition Operation.h:703
result_range getResults()
Definition Operation.h:415
llvm::hash_code hashProperties()
Compute a hash for the op properties (if any).
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:900
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class implements a use iterator for a range of operation results.
Definition ValueRange.h:350
UseIterator(ResultRange results, bool end=false)
Initialize the UseIterator.
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceUsesWithIf(ValuesT &&values, function_ref< bool(OpOperand &)> shouldReplace)
Replace uses of results of this range with the provided 'values' if the given callback returns true.
Definition ValueRange.h:303
iterator_range< user_iterator > user_range
Definition ValueRange.h:323
use_range getUses() const
Returns a range of all uses of results within this range, which is useful for iterating over all uses...
use_iterator use_begin() const
user_range getUsers()
Returns a range of all users.
ValueUserIterator< use_iterator, OpOperand > user_iterator
Definition ValueRange.h:322
ResultRange(OpResult result)
use_iterator use_end() const
user_iterator user_end()
user_iterator user_begin()
iterator_range< use_iterator > use_range
Definition ValueRange.h:268
UseIterator use_iterator
Definition ValueRange.h:267
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition ValueRange.h:286
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
PointerUnion< const Value *, OpOperand *, detail::OpResultImpl * > OwnerT
The type representing the owner of a ValueRange.
Definition ValueRange.h:391
ValueRange(Arg &&arg LLVM_LIFETIME_BOUND)
Definition ValueRange.h:400
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Value.h:233
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void eraseOperands(unsigned start, unsigned length)
Erase the operands held by the storage within the given range.
MutableArrayRef< OpOperand > getOperands()
Get the operation operands held by the storage.
unsigned size()
Return the number of operands held in the storage.
void setOperands(Operation *owner, ValueRange values)
Replace the operands contained in the storage with the ones provided in 'values'.
OperandStorage(Operation *owner, OpOperand *trailingOperands, ValueRange values)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
AttrTypeReplacer.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
LogicalResult checkEquivalent(Value lhsValue, Value rhsValue)
void markEquivalent(Value lhsResult, Value rhsResult)
LogicalResult checkCommutativeEquivalent(ValueRange lhsRange, ValueRange rhsRange)
DenseMap< Value, Value > equivalentValues
static bool isRegionEquivalentTo(Region *lhs, Region *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent, OperationEquivalence::Flags flags, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two regions (including their subregions) and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addSuccessors(Block *successor)
Adds a successor to the operation state. successor must not be null.
void addRegions(MutableArrayRef< std::unique_ptr< Region > > regions)
Take ownership of a set of regions that should be attached to the Operation.
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
OperationState(Location location, StringRef name)
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult setProperties(Operation *op, function_ref< InFlightDiagnostic()> emitError) const