9 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
17 #include "llvm/Support/raw_ostream.h"
39 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
43 new (&
nonConState) NonContiguousState(std::move(rhs.nonConState));
46 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
60 template <
typename IteratorT>
64 NonContiguousState(std::forward<IteratorT>(iterator));
74 new (&indexer.
conState) ContiguousState(firstEltPtr);
80 T
at(uint64_t index)
const {
88 : isContiguous(isContiguous), isSplat(isSplat),
conState(nullptr) {}
92 class ContiguousState {
94 ContiguousState(
const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}
98 const T &
at(uint64_t index)
const {
99 return *(
reinterpret_cast<const T *
>(firstEltPtr) + index);
103 const void *firstEltPtr;
108 class NonContiguousState {
113 struct OpaqueIteratorBase {
114 virtual ~OpaqueIteratorBase() =
default;
115 virtual std::unique_ptr<OpaqueIteratorBase>
clone()
const = 0;
120 template <
typename T>
121 struct OpaqueIteratorValueBase :
public OpaqueIteratorBase {
122 virtual T
at(uint64_t index) = 0;
126 template <
typename IteratorT,
typename T>
127 struct OpaqueIterator :
public OpaqueIteratorValueBase<T> {
128 template <
typename ItTy,
typename FuncTy,
typename FuncReturnTy>
129 static void isMappedIteratorTestFn(
130 llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
131 template <
typename U,
typename... Args>
132 using is_mapped_iterator =
133 decltype(isMappedIteratorTestFn(std::declval<U>()));
134 template <
typename U>
135 using detect_is_mapped_iterator =
136 llvm::is_detected<is_mapped_iterator, U>;
139 template <
typename ItT>
140 static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
141 atImpl(ItT &&it, uint64_t index) {
142 return *std::next(it, index);
144 template <
typename ItT>
145 static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
146 atImpl(ItT &&it, uint64_t index) {
148 return it.getFunction()(*std::next(it.getCurrent(), index));
152 template <
typename U>
153 OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
154 std::unique_ptr<OpaqueIteratorBase>
clone() const final {
155 return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
159 T
at(uint64_t index)
final {
return atImpl(iterator, index); }
167 template <
typename IteratorT,
typename T =
typename llvm::remove_cvref_t<
168 decltype(*std::declval<IteratorT>())>>
169 NonContiguousState(IteratorT iterator)
170 : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
171 NonContiguousState(
const NonContiguousState &other)
172 : iterator(other.iterator->
clone()) {}
173 NonContiguousState(NonContiguousState &&other) =
default;
176 template <
typename T>
177 T
at(uint64_t index)
const {
178 auto *valueIt =
static_cast<OpaqueIteratorValueBase<T> *
>(iterator.get());
179 return valueIt->at(index);
183 std::unique_ptr<OpaqueIteratorBase> iterator;
198 template <
typename T>
200 :
public llvm::iterator_facade_base<ElementsAttrIterator<T>,
201 std::random_access_iterator_tag, T,
202 std::ptrdiff_t, T, T> {
205 : indexer(std::move(indexer)), index(dataIndex) {}
209 return index - rhs.index;
212 return index == rhs.index;
215 return index < rhs.index;
235 template <
typename IteratorT>
249 return *std::next(this->begin(), index);
253 size_t size()
const {
return llvm::size(*
this); }
257 ShapedType shapeType;
282 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
291 template <
typename IteratorT>
295 return (*
this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
300 template <
typename T>
301 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
302 if (std::optional<iterator<T>> iterator = try_value_begin<T>())
303 return std::move(*iterator);
305 <<
"ElementsAttr does not provide iteration facilities for type `"
306 << llvm::getTypeName<T>() <<
"`, see attribute: " << *
this <<
"\n";
307 llvm_unreachable(
"invalid `T` for ElementsAttr::getValues");
309 template <
typename T>
310 auto ElementsAttr::try_value_begin() const
311 -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
312 FailureOr<detail::ElementsAttrIndexer> indexer =
313 getValuesImpl(TypeID::get<T>());
316 return iterator<T>(std::move(*indexer), 0);
This class implements a generic iterator for ElementsAttr.
ptrdiff_t operator-(const ElementsAttrIterator &rhs) const
T operator*() const
Return the value at the current iterator position.
ElementsAttrIterator & operator+=(ptrdiff_t offset)
bool operator<(const ElementsAttrIterator &rhs) const
ElementsAttrIterator & operator-=(ptrdiff_t offset)
ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
bool operator==(const ElementsAttrIterator &rhs) const
This class provides iterator utilities for an ElementsAttr range.
typename IteratorT::reference reference
reference operator[](ArrayRef< uint64_t > index) const
Return the value at the given index.
size_t size() const
Return the size of this range.
reference operator[](uint64_t index) const
ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
ElementsAttrRange(ShapedType shapeType, const llvm::iterator_range< IteratorT > &range)
Include the generated interface declarations.
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class provides support for indexing into the element range of an ElementsAttr.
static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr)
static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator)
Construct an indexer for a non-contiguous range starting at the given iterator.
T at(uint64_t index) const
Access the element at the given index.
NonContiguousState nonConState
ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
ElementsAttrIndexer(ElementsAttrIndexer &&rhs)