MLIR  20.0.0git
BuiltinAttributeInterfaces.h
Go to the documentation of this file.
1 //===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10 #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
11 
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/Types.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <complex>
18 #include <optional>
19 
20 namespace mlir {
21 
22 //===----------------------------------------------------------------------===//
23 // ElementsAttr
24 //===----------------------------------------------------------------------===//
25 namespace detail {
26 /// This class provides support for indexing into the element range of an
27 /// ElementsAttr. It is used to opaquely wrap either a contiguous range, via
28 /// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via
29 /// `ElementsAttrIndexer::nonContiguous`. A contiguous range is an array-like
30 /// range, where all of the elements are laid out sequentially in memory. A
31 /// non-contiguous range implies no contiguity, and elements may even be
32 /// materialized when indexing, such as the case for a mapped_range.
34 public:
36  : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {}
38  : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
39  if (isContiguous)
40  conState = rhs.conState;
41  else
42  new (&nonConState) NonContiguousState(std::move(rhs.nonConState));
43  }
45  : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
46  if (isContiguous)
47  conState = rhs.conState;
48  else
49  new (&nonConState) NonContiguousState(rhs.nonConState);
50  }
52  if (!isContiguous)
53  nonConState.~NonContiguousState();
54  }
55 
56  /// Construct an indexer for a non-contiguous range starting at the given
57  /// iterator. A non-contiguous range implies no contiguity, and elements may
58  /// even be materialized when indexing, such as the case for a mapped_range.
59  template <typename IteratorT>
60  static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) {
61  ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat);
62  new (&indexer.nonConState)
63  NonContiguousState(std::forward<IteratorT>(iterator));
64  return indexer;
65  }
66 
67  // Construct an indexer for a contiguous range starting at the given element
68  // pointer. A contiguous range is an array-like range, where all of the
69  // elements are layed out sequentially in memory.
70  template <typename T>
71  static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) {
72  ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat);
73  new (&indexer.conState) ContiguousState(firstEltPtr);
74  return indexer;
75  }
76 
77  /// Access the element at the given index.
78  template <typename T>
79  T at(uint64_t index) const {
80  if (isSplat)
81  index = 0;
82  return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index);
83  }
84 
85 private:
86  ElementsAttrIndexer(bool isContiguous, bool isSplat)
87  : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {}
88 
89  /// This class contains all of the state necessary to index a contiguous
90  /// range.
91  class ContiguousState {
92  public:
93  ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}
94 
95  /// Access the element at the given index.
96  template <typename T>
97  const T &at(uint64_t index) const {
98  return *(reinterpret_cast<const T *>(firstEltPtr) + index);
99  }
100 
101  private:
102  const void *firstEltPtr;
103  };
104 
105  /// This class contains all of the state necessary to index a non-contiguous
106  /// range.
107  class NonContiguousState {
108  private:
109  /// This class is used to represent the abstract base of an opaque iterator.
110  /// This allows for all iterator and element types to be completely
111  /// type-erased.
112  struct OpaqueIteratorBase {
113  virtual ~OpaqueIteratorBase() = default;
114  virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0;
115  };
116  /// This class is used to represent the abstract base of an opaque iterator
117  /// that iterates over elements of type `T`. This allows for all iterator
118  /// types to be completely type-erased.
119  template <typename T>
120  struct OpaqueIteratorValueBase : public OpaqueIteratorBase {
121  virtual T at(uint64_t index) = 0;
122  };
123  /// This class is used to represent an opaque handle to an iterator of type
124  /// `IteratorT` that iterates over elements of type `T`.
125  template <typename IteratorT, typename T>
126  struct OpaqueIterator : public OpaqueIteratorValueBase<T> {
127  template <typename ItTy, typename FuncTy, typename FuncReturnTy>
128  static void isMappedIteratorTestFn(
129  llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
130  template <typename U, typename... Args>
131  using is_mapped_iterator =
132  decltype(isMappedIteratorTestFn(std::declval<U>()));
133  template <typename U>
134  using detect_is_mapped_iterator =
135  llvm::is_detected<is_mapped_iterator, U>;
136 
137  /// Access the element within the iterator at the given index.
138  template <typename ItT>
139  static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
140  atImpl(ItT &&it, uint64_t index) {
141  return *std::next(it, index);
142  }
143  template <typename ItT>
144  static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
145  atImpl(ItT &&it, uint64_t index) {
146  // Special case mapped_iterator to avoid copying the function.
147  return it.getFunction()(*std::next(it.getCurrent(), index));
148  }
149 
150  public:
151  template <typename U>
152  OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
153  std::unique_ptr<OpaqueIteratorBase> clone() const final {
154  return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
155  }
156 
157  /// Access the element at the given index.
158  T at(uint64_t index) final { return atImpl(iterator, index); }
159 
160  private:
161  IteratorT iterator;
162  };
163 
164  public:
165  /// Construct the state with the given iterator type.
166  template <typename IteratorT, typename T = typename llvm::remove_cvref_t<
167  decltype(*std::declval<IteratorT>())>>
168  NonContiguousState(IteratorT iterator)
169  : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
170  NonContiguousState(const NonContiguousState &other)
171  : iterator(other.iterator->clone()) {}
172  NonContiguousState(NonContiguousState &&other) = default;
173 
174  /// Access the element at the given index.
175  template <typename T>
176  T at(uint64_t index) const {
177  auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get());
178  return valueIt->at(index);
179  }
180 
181  /// The opaque iterator state.
182  std::unique_ptr<OpaqueIteratorBase> iterator;
183  };
184 
185  /// A boolean indicating if this range is contiguous or not.
186  bool isContiguous;
187  /// A boolean indicating if this range is a splat.
188  bool isSplat;
189  /// The underlying range state.
190  union {
191  ContiguousState conState;
192  NonContiguousState nonConState;
193  };
194 };
195 
196 /// This class implements a generic iterator for ElementsAttr.
197 template <typename T>
199  : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
200  std::random_access_iterator_tag, T,
201  std::ptrdiff_t, T, T> {
202 public:
203  ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
204  : indexer(std::move(indexer)), index(dataIndex) {}
205 
206  // Boilerplate iterator methods.
207  ptrdiff_t operator-(const ElementsAttrIterator &rhs) const {
208  return index - rhs.index;
209  }
210  bool operator==(const ElementsAttrIterator &rhs) const {
211  return index == rhs.index;
212  }
213  bool operator<(const ElementsAttrIterator &rhs) const {
214  return index < rhs.index;
215  }
216  ElementsAttrIterator &operator+=(ptrdiff_t offset) {
217  index += offset;
218  return *this;
219  }
220  ElementsAttrIterator &operator-=(ptrdiff_t offset) {
221  index -= offset;
222  return *this;
223  }
224 
225  /// Return the value at the current iterator position.
226  T operator*() const { return indexer.at<T>(index); }
227 
228 private:
229  ElementsAttrIndexer indexer;
230  ptrdiff_t index;
231 };
232 
233 /// This class provides iterator utilities for an ElementsAttr range.
234 template <typename IteratorT>
235 class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
236 public:
237  using reference = typename IteratorT::reference;
238 
239  ElementsAttrRange(ShapedType shapeType,
240  const llvm::iterator_range<IteratorT> &range)
241  : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
242  ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
243  : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
244 
245  /// Return the value at the given index.
247  reference operator[](uint64_t index) const {
248  return *std::next(this->begin(), index);
249  }
250 
251  /// Return the size of this range.
252  size_t size() const { return llvm::size(*this); }
253 
254 private:
255  /// The shaped type of the parent ElementsAttr.
256  ShapedType shapeType;
257 };
258 
259 } // namespace detail
260 
261 //===----------------------------------------------------------------------===//
262 // MemRefLayoutAttrInterface
263 //===----------------------------------------------------------------------===//
264 
265 namespace detail {
266 
267 // Verify the affine map 'm' can be used as a layout specification
268 // for memref with 'shape'.
269 LogicalResult
270 verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
271  function_ref<InFlightDiagnostic()> emitError);
272 
273 } // namespace detail
274 
275 } // namespace mlir
276 
277 //===----------------------------------------------------------------------===//
278 // Tablegen Interface Declarations
279 //===----------------------------------------------------------------------===//
280 
281 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
282 
283 //===----------------------------------------------------------------------===//
284 // ElementsAttr
285 //===----------------------------------------------------------------------===//
286 
287 namespace mlir {
288 namespace detail {
289 /// Return the value at the given index.
290 template <typename IteratorT>
292  -> reference {
293  // Skip to the element corresponding to the flattened index.
294  return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
295 }
296 } // namespace detail
297 
298 /// Return the elements of this attribute as a value of type 'T'.
299 template <typename T>
300 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
301  if (std::optional<iterator<T>> iterator = try_value_begin<T>())
302  return std::move(*iterator);
303  llvm::errs()
304  << "ElementsAttr does not provide iteration facilities for type `"
305  << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n";
306  llvm_unreachable("invalid `T` for ElementsAttr::getValues");
307 }
308 template <typename T>
309 auto ElementsAttr::try_value_begin() const
310  -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
311  FailureOr<detail::ElementsAttrIndexer> indexer =
312  getValuesImpl(TypeID::get<T>());
313  if (failed(indexer))
314  return std::nullopt;
315  return iterator<T>(std::move(*indexer), 0);
316 }
317 } // namespace mlir.
318 
319 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
This class implements a generic iterator for ElementsAttr.
ptrdiff_t operator-(const ElementsAttrIterator &rhs) const
T operator*() const
Return the value at the current iterator position.
ElementsAttrIterator & operator+=(ptrdiff_t offset)
bool operator<(const ElementsAttrIterator &rhs) const
ElementsAttrIterator & operator-=(ptrdiff_t offset)
ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
bool operator==(const ElementsAttrIterator &rhs) const
This class provides iterator utilities for an ElementsAttr range.
typename IteratorT::reference reference
reference operator[](ArrayRef< uint64_t > index) const
Return the value at the given index.
size_t size() const
Return the size of this range.
reference operator[](uint64_t index) const
ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
ElementsAttrRange(ShapedType shapeType, const llvm::iterator_range< IteratorT > &range)
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef< int64_t > shape, function_ref< InFlightDiagnostic()> emitError)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
This class provides support for indexing into the element range of an ElementsAttr.
static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr)
static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator)
Construct an indexer for a non-contiguous range starting at the given iterator.
T at(uint64_t index) const
Access the element at the given index.
ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
ElementsAttrIndexer(ElementsAttrIndexer &&rhs)