MLIR  17.0.0git
BuiltinAttributeInterfaces.h
Go to the documentation of this file.
1 //===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
11 
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/Any.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include <complex>
20 #include <optional>
21 
22 namespace mlir {
23 
24 //===----------------------------------------------------------------------===//
25 // ElementsAttr
26 //===----------------------------------------------------------------------===//
27 namespace detail {
28 /// This class provides support for indexing into the element range of an
29 /// ElementsAttr. It is used to opaquely wrap either a contiguous range, via
30 /// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via
31 /// `ElementsAttrIndexer::nonContiguous`. A contiguous range is an array-like
32 /// range, where all of the elements are laid out sequentially in memory. A
33 /// non-contiguous range implies no contiguity, and elements may even be
34 /// materialized when indexing, such as the case for a mapped_range.
36 public:
38  : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {}
40  : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
41  if (isContiguous)
42  conState = rhs.conState;
43  else
44  new (&nonConState) NonContiguousState(std::move(rhs.nonConState));
45  }
47  : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
48  if (isContiguous)
49  conState = rhs.conState;
50  else
51  new (&nonConState) NonContiguousState(rhs.nonConState);
52  }
54  if (!isContiguous)
55  nonConState.~NonContiguousState();
56  }
57 
58  /// Construct an indexer for a non-contiguous range starting at the given
59  /// iterator. A non-contiguous range implies no contiguity, and elements may
60  /// even be materialized when indexing, such as the case for a mapped_range.
61  template <typename IteratorT>
62  static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) {
63  ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat);
64  new (&indexer.nonConState)
65  NonContiguousState(std::forward<IteratorT>(iterator));
66  return indexer;
67  }
68 
69  // Construct an indexer for a contiguous range starting at the given element
70  // pointer. A contiguous range is an array-like range, where all of the
71  // elements are layed out sequentially in memory.
72  template <typename T>
73  static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) {
74  ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat);
75  new (&indexer.conState) ContiguousState(firstEltPtr);
76  return indexer;
77  }
78 
79  /// Access the element at the given index.
80  template <typename T>
81  T at(uint64_t index) const {
82  if (isSplat)
83  index = 0;
84  return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index);
85  }
86 
87 private:
88  ElementsAttrIndexer(bool isContiguous, bool isSplat)
89  : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {}
90 
/// Holds the state needed to index a contiguous range: a type-erased pointer to
/// the first element, with the remaining elements stored directly after it.
class ContiguousState {
public:
  ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}

  /// Return a reference to the element `index` positions past the first one,
  /// treating the stored pointer as pointing at elements of type `T`. The
  /// element type is not checked.
  template <typename T>
  const T &at(uint64_t index) const {
    const T *elements = static_cast<const T *>(firstEltPtr);
    return elements[index];
  }

private:
  /// Type-erased pointer to the first element of the range.
  const void *firstEltPtr;
};
106 
107  /// This class contains all of the state necessary to index a non-contiguous
108  /// range.
109  class NonContiguousState {
110  private:
111  /// This class is used to represent the abstract base of an opaque iterator.
112  /// This allows for all iterator and element types to be completely
113  /// type-erased.
114  struct OpaqueIteratorBase {
115  virtual ~OpaqueIteratorBase() = default;
116  virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0;
117  };
118  /// This class is used to represent the abstract base of an opaque iterator
119  /// that iterates over elements of type `T`. This allows for all iterator
120  /// types to be completely type-erased.
121  template <typename T>
122  struct OpaqueIteratorValueBase : public OpaqueIteratorBase {
123  virtual T at(uint64_t index) = 0;
124  };
125  /// This class is used to represent an opaque handle to an iterator of type
126  /// `IteratorT` that iterates over elements of type `T`.
127  template <typename IteratorT, typename T>
128  struct OpaqueIterator : public OpaqueIteratorValueBase<T> {
129  template <typename ItTy, typename FuncTy, typename FuncReturnTy>
130  static void isMappedIteratorTestFn(
131  llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
132  template <typename U, typename... Args>
133  using is_mapped_iterator =
134  decltype(isMappedIteratorTestFn(std::declval<U>()));
135  template <typename U>
136  using detect_is_mapped_iterator =
137  llvm::is_detected<is_mapped_iterator, U>;
138 
139  /// Access the element within the iterator at the given index.
140  template <typename ItT>
141  static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
142  atImpl(ItT &&it, uint64_t index) {
143  return *std::next(it, index);
144  }
145  template <typename ItT>
146  static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
147  atImpl(ItT &&it, uint64_t index) {
148  // Special case mapped_iterator to avoid copying the function.
149  return it.getFunction()(*std::next(it.getCurrent(), index));
150  }
151 
152  public:
153  template <typename U>
154  OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
155  std::unique_ptr<OpaqueIteratorBase> clone() const final {
156  return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
157  }
158 
159  /// Access the element at the given index.
160  T at(uint64_t index) final { return atImpl(iterator, index); }
161 
162  private:
163  IteratorT iterator;
164  };
165 
166  public:
167  /// Construct the state with the given iterator type.
168  template <typename IteratorT, typename T = typename llvm::remove_cvref_t<
169  decltype(*std::declval<IteratorT>())>>
170  NonContiguousState(IteratorT iterator)
171  : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
172  NonContiguousState(const NonContiguousState &other)
173  : iterator(other.iterator->clone()) {}
174  NonContiguousState(NonContiguousState &&other) = default;
175 
176  /// Access the element at the given index.
177  template <typename T>
178  T at(uint64_t index) const {
179  auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get());
180  return valueIt->at(index);
181  }
182 
183  /// The opaque iterator state.
184  std::unique_ptr<OpaqueIteratorBase> iterator;
185  };
186 
187  /// A boolean indicating if this range is contiguous or not.
188  bool isContiguous;
189  /// A boolean indicating if this range is a splat.
190  bool isSplat;
191  /// The underlying range state.
192  union {
193  ContiguousState conState;
194  NonContiguousState nonConState;
195  };
196 };
197 
198 /// This class implements a generic iterator for ElementsAttr.
199 template <typename T>
201  : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
202  std::random_access_iterator_tag, T,
203  std::ptrdiff_t, T, T> {
204 public:
205  ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
206  : indexer(std::move(indexer)), index(dataIndex) {}
207 
208  // Boilerplate iterator methods.
209  ptrdiff_t operator-(const ElementsAttrIterator &rhs) const {
210  return index - rhs.index;
211  }
212  bool operator==(const ElementsAttrIterator &rhs) const {
213  return index == rhs.index;
214  }
215  bool operator<(const ElementsAttrIterator &rhs) const {
216  return index < rhs.index;
217  }
218  ElementsAttrIterator &operator+=(ptrdiff_t offset) {
219  index += offset;
220  return *this;
221  }
222  ElementsAttrIterator &operator-=(ptrdiff_t offset) {
223  index -= offset;
224  return *this;
225  }
226 
227  /// Return the value at the current iterator position.
228  T operator*() const { return indexer.at<T>(index); }
229 
230 private:
231  ElementsAttrIndexer indexer;
232  ptrdiff_t index;
233 };
234 
235 /// This class provides iterator utilities for an ElementsAttr range.
236 template <typename IteratorT>
237 class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
238 public:
239  using reference = typename IteratorT::reference;
240 
241  ElementsAttrRange(ShapedType shapeType,
242  const llvm::iterator_range<IteratorT> &range)
243  : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
244  ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
245  : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
246 
247  /// Return the value at the given index.
249  reference operator[](uint64_t index) const {
250  return *std::next(this->begin(), index);
251  }
252 
253  /// Return the size of this range.
254  size_t size() const { return llvm::size(*this); }
255 
256 private:
257  /// The shaped type of the parent ElementsAttr.
258  ShapedType shapeType;
259 };
260 
261 } // namespace detail
262 
263 //===----------------------------------------------------------------------===//
264 // MemRefLayoutAttrInterface
265 //===----------------------------------------------------------------------===//
266 
267 namespace detail {
268 
269 // Verify the affine map 'm' can be used as a layout specification
270 // for memref with 'shape'.
271 LogicalResult
272 verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
273  function_ref<InFlightDiagnostic()> emitError);
274 
275 } // namespace detail
276 
277 } // namespace mlir
278 
279 //===----------------------------------------------------------------------===//
280 // Tablegen Interface Declarations
281 //===----------------------------------------------------------------------===//
282 
283 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
284 
285 //===----------------------------------------------------------------------===//
286 // ElementsAttr
287 //===----------------------------------------------------------------------===//
288 
289 namespace mlir {
290 namespace detail {
291 /// Return the value at the given index.
292 template <typename IteratorT>
294  -> reference {
295  // Skip to the element corresponding to the flattened index.
296  return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
297 }
298 } // namespace detail
299 
300 /// Return the elements of this attribute as a value of type 'T'.
301 template <typename T>
302 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
303  if (std::optional<iterator<T>> iterator = try_value_begin<T>())
304  return std::move(*iterator);
305  llvm::errs()
306  << "ElementsAttr does not provide iteration facilities for type `"
307  << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n";
308  llvm_unreachable("invalid `T` for ElementsAttr::getValues");
309 }
310 template <typename T>
311 auto ElementsAttr::try_value_begin() const
312  -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
313  FailureOr<detail::ElementsAttrIndexer> indexer =
314  getValuesImpl(TypeID::get<T>());
315  if (failed(indexer))
316  return std::nullopt;
317  return iterator<T>(std::move(*indexer), 0);
318 }
319 } // namespace mlir.
320 
321 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
This class implements a generic iterator for ElementsAttr.
ptrdiff_t operator-(const ElementsAttrIterator &rhs) const
T operator*() const
Return the value at the current iterator position.
ElementsAttrIterator & operator+=(ptrdiff_t offset)
bool operator<(const ElementsAttrIterator &rhs) const
ElementsAttrIterator & operator-=(ptrdiff_t offset)
ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
bool operator==(const ElementsAttrIterator &rhs) const
This class provides iterator utilities for an ElementsAttr range.
typename IteratorT::reference reference
reference operator[](ArrayRef< uint64_t > index) const
Return the value at the given index.
size_t size() const
Return the size of this range.
reference operator[](uint64_t index) const
ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
ElementsAttrRange(ShapedType shapeType, const llvm::iterator_range< IteratorT > &range)
Include the generated interface declarations.
Definition: CallGraph.h:229
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:148
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class provides support for indexing into the element range of an ElementsAttr.
static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr)
static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator)
Construct an indexer for a non-contiguous range starting at the given iterator.
T at(uint64_t index) const
Access the element at the given index.
ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
ElementsAttrIndexer(ElementsAttrIndexer &&rhs)