9 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
17 #include "llvm/ADT/Any.h"
18 #include "llvm/Support/raw_ostream.h"
40 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
44 new (&
nonConState) NonContiguousState(std::move(rhs.nonConState));
47 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
61 template <
typename IteratorT>
65 NonContiguousState(std::forward<IteratorT>(iterator));
75 new (&indexer.
conState) ContiguousState(firstEltPtr);
81 T
at(uint64_t index)
const {
89 : isContiguous(isContiguous), isSplat(isSplat),
conState(nullptr) {}
/// State required to index into a contiguous range: a type-erased pointer to
/// the first element of that range.
class ContiguousState {
public:
  /// Wrap the (type-erased) address of the first element of the range.
  ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}

  /// Return a reference to the element `index` positions past the first
  /// element, viewing the storage as an array of `T`. No bounds or type
  /// checking is performed here.
  template <typename T>
  const T &at(uint64_t index) const {
    // Recovering a typed pointer from `const void *` is a plain static_cast;
    // reinterpret_cast is not needed for this conversion.
    return *(static_cast<const T *>(firstEltPtr) + index);
  }

private:
  /// Type-erased pointer to the first element.
  const void *firstEltPtr;
};
109 class NonContiguousState {
/// Abstract, fully type-erased root of the opaque iterator hierarchy. It only
/// exposes polymorphic destruction and duplication.
struct OpaqueIteratorBase {
  /// Produce a heap-allocated copy of this iterator.
  virtual auto clone() const -> std::unique_ptr<OpaqueIteratorBase> = 0;

  virtual ~OpaqueIteratorBase() = default;
};
121 template <
typename T>
122 struct OpaqueIteratorValueBase :
public OpaqueIteratorBase {
123 virtual T
at(uint64_t index) = 0;
127 template <
typename IteratorT,
typename T>
128 struct OpaqueIterator :
public OpaqueIteratorValueBase<T> {
129 template <
typename ItTy,
typename FuncTy,
typename FuncReturnTy>
130 static void isMappedIteratorTestFn(
131 llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
132 template <
typename U,
typename... Args>
133 using is_mapped_iterator =
134 decltype(isMappedIteratorTestFn(std::declval<U>()));
135 template <
typename U>
136 using detect_is_mapped_iterator =
137 llvm::is_detected<is_mapped_iterator, U>;
140 template <
typename ItT>
141 static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
142 atImpl(ItT &&it, uint64_t index) {
143 return *std::next(it, index);
145 template <
typename ItT>
146 static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
147 atImpl(ItT &&it, uint64_t index) {
149 return it.getFunction()(*std::next(it.getCurrent(), index));
153 template <
typename U>
154 OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
155 std::unique_ptr<OpaqueIteratorBase>
clone() const final {
156 return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
160 T
at(uint64_t index)
final {
return atImpl(iterator, index); }
168 template <
typename IteratorT,
typename T =
typename llvm::remove_cvref_t<
169 decltype(*std::declval<IteratorT>())>>
170 NonContiguousState(IteratorT iterator)
171 : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
172 NonContiguousState(
const NonContiguousState &other)
173 : iterator(other.iterator->
clone()) {}
174 NonContiguousState(NonContiguousState &&other) =
default;
177 template <
typename T>
178 T
at(uint64_t index)
const {
179 auto *valueIt =
static_cast<OpaqueIteratorValueBase<T> *
>(iterator.get());
180 return valueIt->at(index);
184 std::unique_ptr<OpaqueIteratorBase> iterator;
199 template <
typename T>
201 :
public llvm::iterator_facade_base<ElementsAttrIterator<T>,
202 std::random_access_iterator_tag, T,
203 std::ptrdiff_t, T, T> {
206 : indexer(std::move(indexer)), index(dataIndex) {}
210 return index - rhs.index;
213 return index == rhs.index;
216 return index < rhs.index;
236 template <
typename IteratorT>
250 return *std::next(this->begin(), index);
254 size_t size()
const {
return llvm::size(*
this); }
258 ShapedType shapeType;
283 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
292 template <
typename IteratorT>
296 return (*
this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
301 template <
typename T>
302 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
303 if (std::optional<iterator<T>> iterator = try_value_begin<T>())
304 return std::move(*iterator);
306 <<
"ElementsAttr does not provide iteration facilities for type `"
307 << llvm::getTypeName<T>() <<
"`, see attribute: " << *
this <<
"\n";
308 llvm_unreachable(
"invalid `T` for ElementsAttr::getValues");
310 template <
typename T>
311 auto ElementsAttr::try_value_begin() const
312 -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
313 FailureOr<detail::ElementsAttrIndexer> indexer =
314 getValuesImpl(TypeID::get<T>());
317 return iterator<T>(std::move(*indexer), 0);
This class implements a generic iterator for ElementsAttr.
ptrdiff_t operator-(const ElementsAttrIterator &rhs) const
T operator*() const
Return the value at the current iterator position.
ElementsAttrIterator & operator+=(ptrdiff_t offset)
bool operator<(const ElementsAttrIterator &rhs) const
ElementsAttrIterator & operator-=(ptrdiff_t offset)
ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
bool operator==(const ElementsAttrIterator &rhs) const
This class provides iterator utilities for an ElementsAttr range.
typename IteratorT::reference reference
reference operator[](ArrayRef< uint64_t > index) const
Return the value at the given index.
size_t size() const
Return the size of this range.
reference operator[](uint64_t index) const
ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
ElementsAttrRange(ShapedType shapeType, const llvm::iterator_range< IteratorT > &range)
Include the generated interface declarations.
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class provides support for indexing into the element range of an ElementsAttr.
static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr)
static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator)
Construct an indexer for a non-contiguous range starting at the given iterator.
T at(uint64_t index) const
Access the element at the given index.
NonContiguousState nonConState
ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
ElementsAttrIndexer(ElementsAttrIndexer &&rhs)