MLIR  19.0.0git
BuiltinAttributeInterfaces.h
Go to the documentation of this file.
1 //===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
11 
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/Types.h"
17 #include "llvm/Support/raw_ostream.h"
18 #include <complex>
19 #include <optional>
20 
21 namespace mlir {
22 
23 //===----------------------------------------------------------------------===//
24 // ElementsAttr
25 //===----------------------------------------------------------------------===//
26 namespace detail {
27 /// This class provides support for indexing into the element range of an
28 /// ElementsAttr. It is used to opaquely wrap either a contiguous range, via
29 /// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via
30 /// `ElementsAttrIndexer::nonContiguous`. A contiguous range is an array-like
31 /// range, where all of the elements are laid out sequentially in memory. A
32 /// non-contiguous range implies no contiguity, and elements may even be
33 /// materialized when indexing, such as the case for a mapped_range.
35 public:
37  : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {}
39  : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
40  if (isContiguous)
41  conState = rhs.conState;
42  else
43  new (&nonConState) NonContiguousState(std::move(rhs.nonConState));
44  }
46  : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
47  if (isContiguous)
48  conState = rhs.conState;
49  else
50  new (&nonConState) NonContiguousState(rhs.nonConState);
51  }
53  if (!isContiguous)
54  nonConState.~NonContiguousState();
55  }
56 
57  /// Construct an indexer for a non-contiguous range starting at the given
58  /// iterator. A non-contiguous range implies no contiguity, and elements may
59  /// even be materialized when indexing, such as the case for a mapped_range.
60  template <typename IteratorT>
61  static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) {
62  ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat);
63  new (&indexer.nonConState)
64  NonContiguousState(std::forward<IteratorT>(iterator));
65  return indexer;
66  }
67 
68  // Construct an indexer for a contiguous range starting at the given element
69  // pointer. A contiguous range is an array-like range, where all of the
70  // elements are layed out sequentially in memory.
71  template <typename T>
72  static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) {
73  ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat);
74  new (&indexer.conState) ContiguousState(firstEltPtr);
75  return indexer;
76  }
77 
78  /// Access the element at the given index.
79  template <typename T>
80  T at(uint64_t index) const {
81  if (isSplat)
82  index = 0;
83  return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index);
84  }
85 
86 private:
87  ElementsAttrIndexer(bool isContiguous, bool isSplat)
88  : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {}
89 
/// State needed to index a contiguous range: an untyped pointer to the first
/// element. The element type is supplied by the caller on each access.
class ContiguousState {
public:
  ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}

  /// Return a reference to the element at `index`, treating the storage as an
  /// array of `T`.
  template <typename T>
  const T &at(uint64_t index) const {
    const T *typedBase = static_cast<const T *>(firstEltPtr);
    return typedBase[index];
  }

private:
  /// Address of the first element of the range.
  const void *firstEltPtr;
};
105 
106  /// This class contains all of the state necessary to index a non-contiguous
107  /// range.
108  class NonContiguousState {
109  private:
110  /// This class is used to represent the abstract base of an opaque iterator.
111  /// This allows for all iterator and element types to be completely
112  /// type-erased.
113  struct OpaqueIteratorBase {
114  virtual ~OpaqueIteratorBase() = default;
115  virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0;
116  };
117  /// This class is used to represent the abstract base of an opaque iterator
118  /// that iterates over elements of type `T`. This allows for all iterator
119  /// types to be completely type-erased.
120  template <typename T>
121  struct OpaqueIteratorValueBase : public OpaqueIteratorBase {
122  virtual T at(uint64_t index) = 0;
123  };
124  /// This class is used to represent an opaque handle to an iterator of type
125  /// `IteratorT` that iterates over elements of type `T`.
126  template <typename IteratorT, typename T>
127  struct OpaqueIterator : public OpaqueIteratorValueBase<T> {
128  template <typename ItTy, typename FuncTy, typename FuncReturnTy>
129  static void isMappedIteratorTestFn(
130  llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
131  template <typename U, typename... Args>
132  using is_mapped_iterator =
133  decltype(isMappedIteratorTestFn(std::declval<U>()));
134  template <typename U>
135  using detect_is_mapped_iterator =
136  llvm::is_detected<is_mapped_iterator, U>;
137 
138  /// Access the element within the iterator at the given index.
139  template <typename ItT>
140  static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
141  atImpl(ItT &&it, uint64_t index) {
142  return *std::next(it, index);
143  }
144  template <typename ItT>
145  static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
146  atImpl(ItT &&it, uint64_t index) {
147  // Special case mapped_iterator to avoid copying the function.
148  return it.getFunction()(*std::next(it.getCurrent(), index));
149  }
150 
151  public:
152  template <typename U>
153  OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
154  std::unique_ptr<OpaqueIteratorBase> clone() const final {
155  return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
156  }
157 
158  /// Access the element at the given index.
159  T at(uint64_t index) final { return atImpl(iterator, index); }
160 
161  private:
162  IteratorT iterator;
163  };
164 
165  public:
166  /// Construct the state with the given iterator type.
167  template <typename IteratorT, typename T = typename llvm::remove_cvref_t<
168  decltype(*std::declval<IteratorT>())>>
169  NonContiguousState(IteratorT iterator)
170  : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
171  NonContiguousState(const NonContiguousState &other)
172  : iterator(other.iterator->clone()) {}
173  NonContiguousState(NonContiguousState &&other) = default;
174 
175  /// Access the element at the given index.
176  template <typename T>
177  T at(uint64_t index) const {
178  auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get());
179  return valueIt->at(index);
180  }
181 
182  /// The opaque iterator state.
183  std::unique_ptr<OpaqueIteratorBase> iterator;
184  };
185 
186  /// A boolean indicating if this range is contiguous or not.
187  bool isContiguous;
188  /// A boolean indicating if this range is a splat.
189  bool isSplat;
190  /// The underlying range state.
191  union {
192  ContiguousState conState;
193  NonContiguousState nonConState;
194  };
195 };
196 
197 /// This class implements a generic iterator for ElementsAttr.
198 template <typename T>
200  : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
201  std::random_access_iterator_tag, T,
202  std::ptrdiff_t, T, T> {
203 public:
204  ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
205  : indexer(std::move(indexer)), index(dataIndex) {}
206 
207  // Boilerplate iterator methods.
208  ptrdiff_t operator-(const ElementsAttrIterator &rhs) const {
209  return index - rhs.index;
210  }
211  bool operator==(const ElementsAttrIterator &rhs) const {
212  return index == rhs.index;
213  }
214  bool operator<(const ElementsAttrIterator &rhs) const {
215  return index < rhs.index;
216  }
217  ElementsAttrIterator &operator+=(ptrdiff_t offset) {
218  index += offset;
219  return *this;
220  }
221  ElementsAttrIterator &operator-=(ptrdiff_t offset) {
222  index -= offset;
223  return *this;
224  }
225 
226  /// Return the value at the current iterator position.
227  T operator*() const { return indexer.at<T>(index); }
228 
229 private:
230  ElementsAttrIndexer indexer;
231  ptrdiff_t index;
232 };
233 
234 /// This class provides iterator utilities for an ElementsAttr range.
235 template <typename IteratorT>
236 class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
237 public:
238  using reference = typename IteratorT::reference;
239 
240  ElementsAttrRange(ShapedType shapeType,
241  const llvm::iterator_range<IteratorT> &range)
242  : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
243  ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
244  : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
245 
246  /// Return the value at the given index.
248  reference operator[](uint64_t index) const {
249  return *std::next(this->begin(), index);
250  }
251 
252  /// Return the size of this range.
253  size_t size() const { return llvm::size(*this); }
254 
255 private:
256  /// The shaped type of the parent ElementsAttr.
257  ShapedType shapeType;
258 };
259 
260 } // namespace detail
261 
262 //===----------------------------------------------------------------------===//
263 // MemRefLayoutAttrInterface
264 //===----------------------------------------------------------------------===//
265 
266 namespace detail {
267 
268 /// Verify that the affine map 'm' can be used as a layout specification
269 /// for a memref with the given 'shape'.
270 LogicalResult
271 verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
272  function_ref<InFlightDiagnostic()> emitError);
273 
274 } // namespace detail
275 
276 } // namespace mlir
277 
278 //===----------------------------------------------------------------------===//
279 // Tablegen Interface Declarations
280 //===----------------------------------------------------------------------===//
281 
282 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
283 
284 //===----------------------------------------------------------------------===//
285 // ElementsAttr
286 //===----------------------------------------------------------------------===//
287 
288 namespace mlir {
289 namespace detail {
290 /// Return the value at the given index.
291 template <typename IteratorT>
293  -> reference {
294  // Skip to the element corresponding to the flattened index.
295  return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
296 }
297 } // namespace detail
298 
299 /// Return the elements of this attribute as a value of type 'T'.
300 template <typename T>
301 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
302  if (std::optional<iterator<T>> iterator = try_value_begin<T>())
303  return std::move(*iterator);
304  llvm::errs()
305  << "ElementsAttr does not provide iteration facilities for type `"
306  << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n";
307  llvm_unreachable("invalid `T` for ElementsAttr::getValues");
308 }
309 template <typename T>
310 auto ElementsAttr::try_value_begin() const
311  -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
312  FailureOr<detail::ElementsAttrIndexer> indexer =
313  getValuesImpl(TypeID::get<T>());
314  if (failed(indexer))
315  return std::nullopt;
316  return iterator<T>(std::move(*indexer), 0);
317 }
318 } // namespace mlir.
319 
320 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
This class implements a generic iterator for ElementsAttr.
ptrdiff_t operator-(const ElementsAttrIterator &rhs) const
T operator*() const
Return the value at the current iterator position.
ElementsAttrIterator & operator+=(ptrdiff_t offset)
bool operator<(const ElementsAttrIterator &rhs) const
ElementsAttrIterator & operator-=(ptrdiff_t offset)
ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
bool operator==(const ElementsAttrIterator &rhs) const
This class provides iterator utilities for an ElementsAttr range.
typename IteratorT::reference reference
reference operator[](ArrayRef< uint64_t > index) const
Return the value at the given index.
size_t size() const
Return the size of this range.
reference operator[](uint64_t index) const
ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
ElementsAttrRange(ShapedType shapeType, const llvm::iterator_range< IteratorT > &range)
Include the generated interface declarations.
Definition: CallGraph.h:229
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:147
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class provides support for indexing into the element range of an ElementsAttr.
static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr)
static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator)
Construct an indexer for a non-contiguous range starting at the given iterator.
T at(uint64_t index) const
Access the element at the given index.
ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
ElementsAttrIndexer(ElementsAttrIndexer &&rhs)