MLIR 22.0.0git
BuiltinAttributeInterfaces.h
Go to the documentation of this file.
1//===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
11
12#include "mlir/IR/AffineMap.h"
13#include "mlir/IR/Attributes.h"
15#include "mlir/IR/Types.h"
16#include "llvm/Support/raw_ostream.h"
17#include <complex>
18#include <optional>
19
20namespace mlir {
21
22//===----------------------------------------------------------------------===//
23// ElementsAttr
24//===----------------------------------------------------------------------===//
25namespace detail {
26/// This class provides support for indexing into the element range of an
27/// ElementsAttr. It is used to opaquely wrap either a contiguous range, via
28/// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via
29/// `ElementsAttrIndexer::nonContiguous`. A contiguous range is an array-like
30/// range, where all of the elements are laid out sequentially in memory. A
31/// non-contiguous range implies no contiguity, and elements may even be
32/// materialized when indexing, such as the case for a mapped_range.
34public:
36 : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {}
38 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
39 if (isContiguous)
40 conState = rhs.conState;
41 else
42 new (&nonConState) NonContiguousState(std::move(rhs.nonConState));
43 }
45 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
46 if (isContiguous)
47 conState = rhs.conState;
48 else
49 new (&nonConState) NonContiguousState(rhs.nonConState);
50 }
52 if (!isContiguous)
53 nonConState.~NonContiguousState();
54 }
55
56 /// Construct an indexer for a non-contiguous range starting at the given
57 /// iterator. A non-contiguous range implies no contiguity, and elements may
58 /// even be materialized when indexing, such as the case for a mapped_range.
59 template <typename IteratorT>
60 static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) {
61 ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat);
62 new (&indexer.nonConState)
63 NonContiguousState(std::forward<IteratorT>(iterator));
64 return indexer;
65 }
66
67 // Construct an indexer for a contiguous range starting at the given element
68 // pointer. A contiguous range is an array-like range, where all of the
69 // elements are layed out sequentially in memory.
70 template <typename T>
71 static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) {
72 ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat);
73 new (&indexer.conState) ContiguousState(firstEltPtr);
74 return indexer;
75 }
76
77 /// Access the element at the given index.
78 template <typename T>
79 T at(uint64_t index) const {
80 if (isSplat)
81 index = 0;
82 return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index);
83 }
84
85private:
86 ElementsAttrIndexer(bool isContiguous, bool isSplat)
87 : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {}
88
/// Holds what is needed to index a contiguous range: an untyped pointer to
/// its first element.
class ContiguousState {
public:
  ContiguousState(const void *firstEltPtr) : base(firstEltPtr) {}

  /// Return a reference to the element at `index`, viewing the stored
  /// pointer as an array of `T`. No bounds or type check is performed.
  template <typename T>
  const T &at(uint64_t index) const {
    const T *typedBase = static_cast<const T *>(base);
    return typedBase[index];
  }

private:
  /// Untyped pointer to the first element of the range.
  const void *base;
};
104
105 /// This class contains all of the state necessary to index a non-contiguous
106 /// range.
107 class NonContiguousState {
108 private:
109 /// This class is used to represent the abstract base of an opaque iterator.
110 /// This allows for all iterator and element types to be completely
111 /// type-erased.
112 struct OpaqueIteratorBase {
113 virtual ~OpaqueIteratorBase() = default;
114 virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0;
115 };
116 /// This class is used to represent the abstract base of an opaque iterator
117 /// that iterates over elements of type `T`. This allows for all iterator
118 /// types to be completely type-erased.
119 template <typename T>
120 struct OpaqueIteratorValueBase : public OpaqueIteratorBase {
121 virtual T at(uint64_t index) = 0;
122 };
123 /// This class is used to represent an opaque handle to an iterator of type
124 /// `IteratorT` that iterates over elements of type `T`.
125 template <typename IteratorT, typename T>
126 struct OpaqueIterator : public OpaqueIteratorValueBase<T> {
127 template <typename ItTy, typename FuncTy, typename FuncReturnTy>
128 static void isMappedIteratorTestFn(
129 llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
130 template <typename U, typename... Args>
131 using is_mapped_iterator =
132 decltype(isMappedIteratorTestFn(std::declval<U>()));
133 template <typename U>
134 using detect_is_mapped_iterator =
135 llvm::is_detected<is_mapped_iterator, U>;
136
137 /// Access the element within the iterator at the given index.
138 template <typename ItT>
139 static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
140 atImpl(ItT &&it, uint64_t index) {
141 return *std::next(it, index);
142 }
143 template <typename ItT>
144 static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
145 atImpl(ItT &&it, uint64_t index) {
146 // Special case mapped_iterator to avoid copying the function.
147 return it.getFunction()(*std::next(it.getCurrent(), index));
148 }
149
150 public:
151 template <typename U>
152 OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
153 std::unique_ptr<OpaqueIteratorBase> clone() const final {
154 return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
155 }
156
157 /// Access the element at the given index.
158 T at(uint64_t index) final { return atImpl(iterator, index); }
159
160 private:
161 IteratorT iterator;
162 };
163
164 public:
165 /// Construct the state with the given iterator type.
166 template <typename IteratorT, typename T = typename llvm::remove_cvref_t<
167 decltype(*std::declval<IteratorT>())>>
168 NonContiguousState(IteratorT iterator)
169 : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
170 NonContiguousState(const NonContiguousState &other)
171 : iterator(other.iterator->clone()) {}
172 NonContiguousState(NonContiguousState &&other) = default;
173
174 /// Access the element at the given index.
175 template <typename T>
176 T at(uint64_t index) const {
177 auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get());
178 return valueIt->at(index);
179 }
180
181 /// The opaque iterator state.
182 std::unique_ptr<OpaqueIteratorBase> iterator;
183 };
184
185 /// A boolean indicating if this range is contiguous or not.
186 bool isContiguous;
187 /// A boolean indicating if this range is a splat.
188 bool isSplat;
189 /// The underlying range state.
190 union {
191 ContiguousState conState;
192 NonContiguousState nonConState;
193 };
194};
195
196/// This class implements a generic iterator for ElementsAttr.
197template <typename T>
199 : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
200 std::random_access_iterator_tag, T,
201 std::ptrdiff_t, T, T> {
202public:
203 ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
204 : indexer(std::move(indexer)), index(dataIndex) {}
205
206 // Boilerplate iterator methods.
207 ptrdiff_t operator-(const ElementsAttrIterator &rhs) const {
208 return index - rhs.index;
209 }
211 return index == rhs.index;
212 }
213 bool operator<(const ElementsAttrIterator &rhs) const {
214 return index < rhs.index;
215 }
216 ElementsAttrIterator &operator+=(ptrdiff_t offset) {
217 index += offset;
218 return *this;
219 }
220 ElementsAttrIterator &operator-=(ptrdiff_t offset) {
221 index -= offset;
222 return *this;
223 }
224
225 /// Return the value at the current iterator position.
226 T operator*() const { return indexer.at<T>(index); }
227
228private:
229 ElementsAttrIndexer indexer;
230 ptrdiff_t index;
231};
232
233/// This class provides iterator utilities for an ElementsAttr range.
234template <typename IteratorT>
235class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
236public:
237 using reference = typename IteratorT::reference;
238
239 ElementsAttrRange(ShapedType shapeType,
241 : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
242 ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
243 : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
244
245 /// Return the value at the given index.
246 reference operator[](ArrayRef<uint64_t> index) const;
247 reference operator[](uint64_t index) const {
248 return *std::next(this->begin(), index);
249 }
250
251 /// Return the size of this range.
252 size_t size() const { return llvm::size(*this); }
253
254private:
255 /// The shaped type of the parent ElementsAttr.
256 ShapedType shapeType;
257};
258
259} // namespace detail
260
261//===----------------------------------------------------------------------===//
262// MemRefLayoutAttrInterface
263//===----------------------------------------------------------------------===//
264
265namespace detail {
266
267// Verify the affine map 'm' can be used as a layout specification
268// for memref with 'shape'.
269LogicalResult
270verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
271 function_ref<InFlightDiagnostic()> emitError);
272
273// Return the strides and offsets that can be inferred from the given affine
274// layout map given the map and a memref shape.
275LogicalResult getAffineMapStridesAndOffset(AffineMap map,
276 ArrayRef<int64_t> shape,
277 SmallVectorImpl<int64_t> &strides,
278 int64_t &offset);
279} // namespace detail
280
281} // namespace mlir
282
283//===----------------------------------------------------------------------===//
284// Tablegen Interface Declarations
285//===----------------------------------------------------------------------===//
286
287#include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
288#include "mlir/IR/OpAsmAttrInterface.h.inc"
289
290//===----------------------------------------------------------------------===//
291// ElementsAttr
292//===----------------------------------------------------------------------===//
293
294namespace mlir {
295namespace detail {
296/// Return the value at the given index.
297template <typename IteratorT>
299 -> reference {
300 // Skip to the element corresponding to the flattened index.
301 return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
302}
303} // namespace detail
304
305/// Return the elements of this attribute as a value of type 'T'.
306template <typename T>
307auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
308 if (std::optional<iterator<T>> iterator = try_value_begin<T>())
309 return std::move(*iterator);
310 llvm::errs()
311 << "ElementsAttr does not provide iteration facilities for type `"
312 << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n";
313 llvm_unreachable("invalid `T` for ElementsAttr::getValues");
314}
315template <typename T>
316auto ElementsAttr::try_value_begin() const
317 -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
318 FailureOr<detail::ElementsAttrIndexer> indexer =
319 getValuesImpl(TypeID::get<T>());
320 if (failed(indexer))
321 return std::nullopt;
322 return iterator<T>(std::move(*indexer), 0);
323}
324} // namespace mlir.
325
326#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
true
Given two iterators into the same block, return "true" if `a` is before `b`.
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
ptrdiff_t operator-(const ElementsAttrIterator &rhs) const
T operator*() const
Return the value at the current iterator position.
bool operator<(const ElementsAttrIterator &rhs) const
ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
ElementsAttrIterator & operator-=(ptrdiff_t offset)
ElementsAttrIterator & operator+=(ptrdiff_t offset)
bool operator==(const ElementsAttrIterator &rhs) const
typename IteratorT::reference reference
reference operator[](ArrayRef< uint64_t > index) const
Return the value at the given index.
size_t size() const
Return the size of this range.
reference operator[](uint64_t index) const
ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
ElementsAttrRange(ShapedType shapeType, const llvm::iterator_range< IteratorT > &range)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
AttrTypeReplacer.
Include the generated interface declarations.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
This class provides support for indexing into the element range of an ElementsAttr.
static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr)
static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator)
Construct an indexer for a non-contiguous range starting at the given iterator.
T at(uint64_t index) const
Access the element at the given index.
ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
ElementsAttrIndexer(ElementsAttrIndexer &&rhs)