9 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
16 #include "llvm/Support/raw_ostream.h"
38 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
42 new (&
nonConState) NonContiguousState(std::move(rhs.nonConState));
45 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
59 template <
typename IteratorT>
63 NonContiguousState(std::forward<IteratorT>(iterator));
73 new (&indexer.
conState) ContiguousState(firstEltPtr);
79 T
at(uint64_t index)
const {
87 : isContiguous(isContiguous), isSplat(isSplat),
conState(nullptr) {}
91 class ContiguousState {
93 ContiguousState(
const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}
97 const T &
at(uint64_t index)
const {
98 return *(
reinterpret_cast<const T *
>(firstEltPtr) + index);
102 const void *firstEltPtr;
107 class NonContiguousState {
112 struct OpaqueIteratorBase {
113 virtual ~OpaqueIteratorBase() =
default;
114 virtual std::unique_ptr<OpaqueIteratorBase>
clone()
const = 0;
119 template <
typename T>
120 struct OpaqueIteratorValueBase :
public OpaqueIteratorBase {
121 virtual T
at(uint64_t index) = 0;
125 template <
typename IteratorT,
typename T>
126 struct OpaqueIterator :
public OpaqueIteratorValueBase<T> {
127 template <
typename ItTy,
typename FuncTy,
typename FuncReturnTy>
128 static void isMappedIteratorTestFn(
129 llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
130 template <
typename U,
typename... Args>
131 using is_mapped_iterator =
132 decltype(isMappedIteratorTestFn(std::declval<U>()));
133 template <
typename U>
134 using detect_is_mapped_iterator =
135 llvm::is_detected<is_mapped_iterator, U>;
138 template <
typename ItT>
139 static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
140 atImpl(ItT &&it, uint64_t index) {
141 return *std::next(it, index);
143 template <
typename ItT>
144 static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
145 atImpl(ItT &&it, uint64_t index) {
147 return it.getFunction()(*std::next(it.getCurrent(), index));
151 template <
typename U>
152 OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
153 std::unique_ptr<OpaqueIteratorBase>
clone() const final {
154 return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
158 T
at(uint64_t index)
final {
return atImpl(iterator, index); }
166 template <
typename IteratorT,
typename T =
typename llvm::remove_cvref_t<
167 decltype(*std::declval<IteratorT>())>>
168 NonContiguousState(IteratorT iterator)
169 : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
170 NonContiguousState(
const NonContiguousState &other)
171 : iterator(other.iterator->
clone()) {}
172 NonContiguousState(NonContiguousState &&other) =
default;
175 template <
typename T>
176 T
at(uint64_t index)
const {
177 auto *valueIt =
static_cast<OpaqueIteratorValueBase<T> *
>(iterator.get());
178 return valueIt->at(index);
182 std::unique_ptr<OpaqueIteratorBase> iterator;
197 template <
typename T>
199 :
public llvm::iterator_facade_base<ElementsAttrIterator<T>,
200 std::random_access_iterator_tag, T,
201 std::ptrdiff_t, T, T> {
204 : indexer(std::move(indexer)), index(dataIndex) {}
208 return index - rhs.index;
211 return index == rhs.index;
214 return index < rhs.index;
234 template <
typename IteratorT>
248 return *std::next(this->begin(), index);
252 size_t size()
const {
return llvm::size(*
this); }
256 ShapedType shapeType;
281 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
290 template <
typename IteratorT>
294 return (*
this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
299 template <
typename T>
300 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
301 if (std::optional<iterator<T>> iterator = try_value_begin<T>())
302 return std::move(*iterator);
304 <<
"ElementsAttr does not provide iteration facilities for type `"
305 << llvm::getTypeName<T>() <<
"`, see attribute: " << *
this <<
"\n";
306 llvm_unreachable(
"invalid `T` for ElementsAttr::getValues");
308 template <
typename T>
309 auto ElementsAttr::try_value_begin() const
310 -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
311 FailureOr<detail::ElementsAttrIndexer> indexer =
312 getValuesImpl(TypeID::get<T>());
315 return iterator<T>(std::move(*indexer), 0);
This class implements a generic iterator for ElementsAttr.
ptrdiff_t operator-(const ElementsAttrIterator &rhs) const
T operator*() const
Return the value at the current iterator position.
ElementsAttrIterator & operator+=(ptrdiff_t offset)
bool operator<(const ElementsAttrIterator &rhs) const
ElementsAttrIterator & operator-=(ptrdiff_t offset)
ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
bool operator==(const ElementsAttrIterator &rhs) const
This class provides iterator utilities for an ElementsAttr range.
typename IteratorT::reference reference
reference operator[](ArrayRef< uint64_t > index) const
Return the value at the given index.
size_t size() const
Return the size of this range.
reference operator[](uint64_t index) const
ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
ElementsAttrRange(ShapedType shapeType, const llvm::iterator_range< IteratorT > &range)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
This class provides support for indexing into the element range of an ElementsAttr.
static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr)
static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator)
Construct an indexer for a non-contiguous range starting at the given iterator.
T at(uint64_t index) const
Access the element at the given index.
NonContiguousState nonConState
ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
ElementsAttrIndexer(ElementsAttrIndexer &&rhs)