MLIR  16.0.0git
CRunnerUtils.h
Go to the documentation of this file.
1 //===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares basic classes and functions to manipulate structured MLIR
10 // types at runtime. Entities in this file must be compliant with C++11 and be
11 // retargetable, including on targets without a C++ runtime.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
16 #define MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
17 
// Export/visibility configuration for the C runner utils library.
//  * Windows: unless the includer already chose MLIR_CRUNNERUTILS_EXPORT,
//    symbols are exported while building the DLL (mlir_c_runner_utils_EXPORTS
//    is defined) and imported otherwise. MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS is
//    only defined in the "building the library" case.
//  * Other platforms: symbols get default visibility and
//    MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS is always defined.
#if defined(_WIN32)
#if !defined(MLIR_CRUNNERUTILS_EXPORT)
#if defined(mlir_c_runner_utils_EXPORTS)
#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllexport)
#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
#else
#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllimport)
#endif // defined(mlir_c_runner_utils_EXPORTS)
#endif // !defined(MLIR_CRUNNERUTILS_EXPORT)
#else  // !defined(_WIN32)
#define MLIR_CRUNNERUTILS_EXPORT __attribute__((visibility("default")))
#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
#endif // defined(_WIN32)
34 
35 #include <array>
36 #include <cassert>
37 #include <cstdint>
38 #include <initializer_list>
39 #include <vector>
40 
41 //===----------------------------------------------------------------------===//
42 // Codegen-compatible structures for Vector type.
43 //===----------------------------------------------------------------------===//
44 namespace mlir {
45 namespace detail {
46 
/// Returns true iff `n` is a positive power of two (1, 2, 4, ...).
/// Non-positive values are rejected explicitly: the bare `n & (n - 1)` trick
/// reports 0 as a power of two (0 & -1 == 0), and `n - 1` overflows for
/// INT_MIN. Kept as a single return statement to remain a C++11 constexpr.
constexpr bool isPowerOf2(int n) { return n > 0 && (n & (n - 1)) == 0; }
48 
/// Returns the smallest power of two that is >= n; every n <= 1 maps to 1.
/// The power-of-two test is done inline, and the whole computation is a single
/// return expression so the function stays a valid C++11 constexpr.
constexpr unsigned nextPowerOf2(int n) {
  return n > 1 ? ((n & (n - 1)) == 0 ? static_cast<unsigned>(n)
                                     : 2u * nextPowerOf2((n + 1) / 2))
               : 1u;
}
52 
/// Storage for a 1-D vector of Dim elements of type T. Only the two
/// specializations on IsPowerOf2 (whether sizeof(T[Dim]) is already a power
/// of two) are defined; the primary template is declared but never defined.
template <typename T, int Dim, bool IsPowerOf2> struct Vector1D;
55 
56 template <typename T, int Dim>
57 struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
59  static_assert(detail::nextPowerOf2(sizeof(T[Dim])) == sizeof(T[Dim]),
60  "size error");
61  }
62  inline T &operator[](unsigned i) { return vector[i]; }
63  inline const T &operator[](unsigned i) const { return vector[i]; }
64 
65 private:
66  T vector[Dim];
67 };
68 
69 // 1-D vector, padded to the next power of 2 allocation.
70 // Specialization occurs to avoid zero size arrays (which fail in -Werror).
71 template <typename T, int Dim>
72 struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
74  static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
75  static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
76  "size error");
77  }
78  inline T &operator[](unsigned i) { return vector[i]; }
79  inline const T &operator[](unsigned i) const { return vector[i]; }
80 
81 private:
82  T vector[Dim];
83  char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
84 };
85 } // namespace detail
86 } // namespace mlir
87 
// N-D vectors recurse down to 1-D: Vector<T, Dim, Dims...> stores Dim
// consecutive elements of type Vector<T, Dims...>.
template <typename T, int Dim, int... Dims>
struct Vector {
  /// Element `i` along the outermost dimension (no bounds checking).
  Vector<T, Dims...> &operator[](unsigned i) { return elements[i]; }
  const Vector<T, Dims...> &operator[](unsigned i) const { return elements[i]; }

private:
  Vector<T, Dims...> elements[Dim];
};
99 
100 // 1-D vectors in LLVM are automatically padded to the next power of 2.
101 // We insert explicit padding in to account for this.
102 template <typename T, int Dim>
103 struct Vector<T, Dim>
104  : public mlir::detail::Vector1D<T, Dim,
105  mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
106 };
107 
108 template <int D1, typename T>
110 template <int D1, int D2, typename T>
112 template <int D1, int D2, int D3, typename T>
114 template <int D1, int D2, int D3, int D4, typename T>
116 
/// Copies arr[1..N-1] into res[0..N-2], i.e. drops the leading entry of an
/// N-element array (this is how a rank N-1 slice derives its sizes/strides).
/// `res` must have room for N - 1 values. The input is taken as const so that
/// read-only arrays are accepted too, and the loop index is signed to match N.
template <int N>
void dropFront(const int64_t arr[N], int64_t *res) {
  for (int i = 1; i < N; ++i)
    res[i - 1] = arr[i];
}
122 
123 //===----------------------------------------------------------------------===//
124 // Codegen-compatible structures for StridedMemRef type.
125 //===----------------------------------------------------------------------===//
126 template <typename T, int Rank>
128 
129 /// StridedMemRef descriptor type with static rank.
130 template <typename T, int N>
133  T *data;
134  int64_t offset;
135  int64_t sizes[N];
136  int64_t strides[N];
137 
138  template <typename Range,
139  typename sfinae = decltype(std::declval<Range>().begin())>
140  T &operator[](Range &&indices) {
141  assert(indices.size() == N &&
142  "indices should match rank in memref subscript");
143  int64_t curOffset = offset;
144  for (int dim = N - 1; dim >= 0; --dim) {
145  int64_t currentIndex = *(indices.begin() + dim);
146  assert(currentIndex < sizes[dim] && "Index overflow");
147  curOffset += currentIndex * strides[dim];
148  }
149  return data[curOffset];
150  }
151 
152  StridedMemrefIterator<T, N> begin() { return {*this}; }
153  StridedMemrefIterator<T, N> end() { return {*this, -1}; }
154 
155  // This operator[] is extremely slow and only for sugaring purposes.
156  StridedMemRefType<T, N - 1> operator[](int64_t idx) {
157  StridedMemRefType<T, N - 1> res;
158  res.basePtr = basePtr;
159  res.data = data;
160  res.offset = offset + idx * strides[0];
161  dropFront<N>(sizes, res.sizes);
162  dropFront<N>(strides, res.strides);
163  return res;
164  }
165 };
166 
167 /// StridedMemRef descriptor type specialized for rank 1.
168 template <typename T>
169 struct StridedMemRefType<T, 1> {
171  T *data;
172  int64_t offset;
173  int64_t sizes[1];
174  int64_t strides[1];
175 
176  template <typename Range,
177  typename sfinae = decltype(std::declval<Range>().begin())>
178  T &operator[](Range indices) {
179  assert(indices.size() == 1 &&
180  "indices should match rank in memref subscript");
181  return (*this)[*indices.begin()];
182  }
183 
184  StridedMemrefIterator<T, 1> begin() { return {*this}; }
185  StridedMemrefIterator<T, 1> end() { return {*this, -1}; }
186 
187  T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
188 };
189 
190 /// StridedMemRef descriptor type specialized for rank 0.
191 template <typename T>
192 struct StridedMemRefType<T, 0> {
194  T *data;
195  int64_t offset;
196 
197  template <typename Range,
198  typename sfinae = decltype(std::declval<Range>().begin())>
199  T &operator[](Range indices) {
200  assert((indices.size() == 0) &&
201  "Expect empty indices for 0-rank memref subscript");
202  return data[offset];
203  }
204 
205  StridedMemrefIterator<T, 0> begin() { return {*this}; }
206  StridedMemrefIterator<T, 0> end() { return {*this, 1}; }
207 };
208 
209 /// Iterate over all elements in a strided memref.
210 template <typename T, int Rank>
211 class StridedMemrefIterator {
212 public:
213  using iterator_category = std::forward_iterator_tag;
214  using value_type = T;
215  using difference_type = std::ptrdiff_t;
216  using pointer = T *;
217  using reference = T &;
218 
220  int64_t offset = 0)
221  : offset(offset), descriptor(&descriptor) {}
223  int dim = Rank - 1;
224  while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
225  offset -= indices[dim] * descriptor->strides[dim];
226  indices[dim] = 0;
227  --dim;
228  }
229  if (dim < 0) {
230  offset = -1;
231  return *this;
232  }
233  ++indices[dim];
234  offset += descriptor->strides[dim];
235  return *this;
236  }
237 
238  reference operator*() { return descriptor->data[offset]; }
239  pointer operator->() { return &descriptor->data[offset]; }
240 
241  const std::array<int64_t, Rank> &getIndices() { return indices; }
242 
243  bool operator==(const StridedMemrefIterator &other) const {
244  return other.offset == offset && other.descriptor == descriptor;
245  }
246 
247  bool operator!=(const StridedMemrefIterator &other) const {
248  return !(*this == other);
249  }
250 
251 private:
252  /// Offset in the buffer. This can be derived from the indices and the
253  /// descriptor.
254  int64_t offset = 0;
255 
256  /// Array of indices in the multi-dimensional memref.
257  std::array<int64_t, Rank> indices = {};
258 
259  /// Descriptor for the strided memref.
260  StridedMemRefType<T, Rank> *descriptor;
261 };
262 
263 /// Iterate over all elements in a 0-ranked strided memref.
264 template <typename T>
266 public:
267  using iterator_category = std::forward_iterator_tag;
268  using value_type = T;
269  using difference_type = std::ptrdiff_t;
270  using pointer = T *;
271  using reference = T &;
272 
273  StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
274  : elt(descriptor.data + offset) {}
275 
277  ++elt;
278  return *this;
279  }
280 
281  reference operator*() { return *elt; }
282  pointer operator->() { return elt; }
283 
284  // There are no indices for a 0-ranked memref, but this API is provided for
285  // consistency with the general case.
286  const std::array<int64_t, 0> &getIndices() {
287  // Since this is a 0-array of indices we can keep a single global const
288  // copy.
289  static const std::array<int64_t, 0> indices = {};
290  return indices;
291  }
292 
293  bool operator==(const StridedMemrefIterator &other) const {
294  return other.elt == elt;
295  }
296 
297  bool operator!=(const StridedMemrefIterator &other) const {
298  return !(*this == other);
299  }
300 
301 private:
302  /// Pointer to the single element in the zero-ranked memref.
303  T *elt;
304 };
305 
306 //===----------------------------------------------------------------------===//
307 // Codegen-compatible structure for UnrankedMemRef type.
308 //===----------------------------------------------------------------------===//
309 // Unranked MemRef
310 template <typename T>
312  int64_t rank;
313  void *descriptor;
314 };
315 
316 //===----------------------------------------------------------------------===//
317 // DynamicMemRefType type.
318 //===----------------------------------------------------------------------===//
319 template <typename T>
321 
322 // A reference to one of the StridedMemRef types.
323 template <typename T>
325 public:
326  int64_t rank;
328  T *data;
329  int64_t offset;
330  const int64_t *sizes;
331  const int64_t *strides;
332 
334  : rank(0), basePtr(memRef.basePtr), data(memRef.data),
335  offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
336  template <int N>
338  : rank(N), basePtr(memRef.basePtr), data(memRef.data),
339  offset(memRef.offset), sizes(memRef.sizes), strides(memRef.strides) {}
340  explicit DynamicMemRefType(const ::UnrankedMemRefType<T> &memRef)
341  : rank(memRef.rank) {
342  auto *desc = static_cast<StridedMemRefType<T, 1> *>(memRef.descriptor);
343  basePtr = desc->basePtr;
344  data = desc->data;
345  offset = desc->offset;
346  sizes = rank == 0 ? nullptr : desc->sizes;
347  strides = sizes + rank;
348  }
349 
350  template <typename Range,
351  typename sfinae = decltype(std::declval<Range>().begin())>
352  T &operator[](Range &&indices) {
353  assert(indices.size() == rank &&
354  "indices should match rank in memref subscript");
355  if (rank == 0)
356  return data[offset];
357 
358  int64_t curOffset = offset;
359  for (int dim = rank - 1; dim >= 0; --dim) {
360  int64_t currentIndex = *(indices.begin() + dim);
361  assert(currentIndex < sizes[dim] && "Index overflow");
362  curOffset += currentIndex * strides[dim];
363  }
364  return data[curOffset];
365  }
366 
367  DynamicMemRefIterator<T> begin() { return {*this}; }
368  DynamicMemRefIterator<T> end() { return {*this, -1}; }
369 
370  // This operator[] is extremely slow and only for sugaring purposes.
372  assert(rank > 0 && "can't make a subscript of a zero ranked array");
373 
374  DynamicMemRefType<T> res(*this);
375  --res.rank;
376  res.offset += idx * res.strides[0];
377  ++res.sizes;
378  ++res.strides;
379  return res;
380  }
381 
382  // This operator* can be used in conjunction with the previous operator[] in
383  // order to access the underlying value in case of zero-ranked memref.
384  T &operator*() {
385  assert(rank == 0 && "not a zero-ranked memRef");
386  return data[offset];
387  }
388 };
389 
390 /// Iterate over all elements in a dynamic memref.
391 template <typename T>
392 class DynamicMemRefIterator {
393 public:
394  using iterator_category = std::forward_iterator_tag;
395  using value_type = T;
396  using difference_type = std::ptrdiff_t;
397  using pointer = T *;
398  using reference = T &;
399 
400  DynamicMemRefIterator(DynamicMemRefType<T> &descriptor, int64_t offset = 0)
401  : offset(offset), descriptor(&descriptor) {
402  indices.resize(descriptor.rank, 0);
403  }
404 
406  if (descriptor->rank == 0) {
407  offset = -1;
408  return *this;
409  }
410 
411  int dim = descriptor->rank - 1;
412 
413  while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
414  offset -= indices[dim] * descriptor->strides[dim];
415  indices[dim] = 0;
416  --dim;
417  }
418 
419  if (dim < 0) {
420  offset = -1;
421  return *this;
422  }
423 
424  ++indices[dim];
425  offset += descriptor->strides[dim];
426  return *this;
427  }
428 
429  reference operator*() { return descriptor->data[offset]; }
430  pointer operator->() { return &descriptor->data[offset]; }
431 
432  const std::vector<int64_t> &getIndices() { return indices; }
433 
434  bool operator==(const DynamicMemRefIterator &other) const {
435  return other.offset == offset && other.descriptor == descriptor;
436  }
437 
438  bool operator!=(const DynamicMemRefIterator &other) const {
439  return !(*this == other);
440  }
441 
442 private:
443  /// Offset in the buffer. This can be derived from the indices and the
444  /// descriptor.
445  int64_t offset = 0;
446 
447  /// Array of indices in the multi-dimensional memref.
448  std::vector<int64_t> indices = {};
449 
450  /// Descriptor for the dynamic memref.
451  DynamicMemRefType<T> *descriptor;
452 };
453 
454 //===----------------------------------------------------------------------===//
455 // Small runtime support library for memref.copy lowering during codegen.
456 //===----------------------------------------------------------------------===//
457 extern "C" MLIR_CRUNNERUTILS_EXPORT void
458 memrefCopy(int64_t elemSize, ::UnrankedMemRefType<char> *src,
459  ::UnrankedMemRefType<char> *dst);
460 
461 //===----------------------------------------------------------------------===//
462 // Small runtime support library for vector.print lowering during codegen.
463 //===----------------------------------------------------------------------===//
464 extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
465 extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
466 extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
467 extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
468 extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
469 extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
470 extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
471 extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
472 
473 //===----------------------------------------------------------------------===//
474 // Small runtime support library for timing execution and printing GFLOPS
475 //===----------------------------------------------------------------------===//
476 extern "C" MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops);
477 extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock();
478 
479 #endif // MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u)
Include the generated interface declarations.
std::ptrdiff_t difference_type
Definition: CRunnerUtils.h:396
DynamicMemRefIterator< T > begin()
Definition: CRunnerUtils.h:367
T & operator[](Range &&indices)
Definition: CRunnerUtils.h:140
StridedMemrefIterator(StridedMemRefType< T, Rank > &descriptor, int64_t offset=0)
Definition: CRunnerUtils.h:219
bool operator!=(const StridedMemrefIterator &other) const
Definition: CRunnerUtils.h:297
std::ptrdiff_t difference_type
Definition: CRunnerUtils.h:215
StridedMemRef descriptor type specialized for rank 1.
Definition: CRunnerUtils.h:169
MLIR_CRUNNERUTILS_EXPORT void printF64(double d)
StridedMemrefIterator< T, 1 > end()
Definition: CRunnerUtils.h:185
const std::array< int64_t, Rank > & getIndices()
Definition: CRunnerUtils.h:241
StridedMemrefIterator< T, 0 > end()
Definition: CRunnerUtils.h:206
StridedMemRefType< T, N - 1 > operator[](int64_t idx)
Definition: CRunnerUtils.h:156
MLIR_CRUNNERUTILS_EXPORT void printF32(float f)
T & operator[](Range indices)
Definition: CRunnerUtils.h:199
DynamicMemRefType< T > operator[](int64_t idx)
Definition: CRunnerUtils.h:371
Iterate over all elements in a 0-ranked strided memref.
Definition: CRunnerUtils.h:265
MLIR_CRUNNERUTILS_EXPORT void printOpen()
const int64_t * sizes
Definition: CRunnerUtils.h:330
bool operator!=(const StridedMemrefIterator &other) const
Definition: CRunnerUtils.h:247
MLIR_CRUNNERUTILS_EXPORT void printClose()
DynamicMemRefIterator(DynamicMemRefType< T > &descriptor, int64_t offset=0)
Definition: CRunnerUtils.h:400
bool operator==(const DynamicMemRefIterator &other) const
Definition: CRunnerUtils.h:434
T & operator[](Range &&indices)
Definition: CRunnerUtils.h:352
MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i)
MLIR_CRUNNERUTILS_EXPORT void printComma()
MLIR_CRUNNERUTILS_EXPORT void memrefCopy(int64_t elemSize, ::UnrankedMemRefType< char > *src, ::UnrankedMemRefType< char > *dst)
MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops)
StridedMemrefIterator< T, N > end()
Definition: CRunnerUtils.h:153
std::forward_iterator_tag iterator_category
Definition: CRunnerUtils.h:394
T & operator[](Range indices)
Definition: CRunnerUtils.h:178
MLIR_CRUNNERUTILS_EXPORT void printNewline()
constexpr unsigned nextPowerOf2(int n)
Definition: CRunnerUtils.h:49
StridedMemrefIterator< T, 0 > begin()
Definition: CRunnerUtils.h:205
#define MLIR_CRUNNERUTILS_EXPORT
Definition: CRunnerUtils.h:31
StridedMemRef descriptor type specialized for rank 0.
Definition: CRunnerUtils.h:192
const Vector< T, Dims... > & operator[](unsigned i) const
Definition: CRunnerUtils.h:92
DynamicMemRefIterator< T > & operator++()
Definition: CRunnerUtils.h:405
DynamicMemRefType(const ::UnrankedMemRefType< T > &memRef)
Definition: CRunnerUtils.h:340
const std::vector< int64_t > & getIndices()
Definition: CRunnerUtils.h:432
StridedMemrefIterator< T, 0 > & operator++()
Definition: CRunnerUtils.h:276
const T & operator[](unsigned i) const
Definition: CRunnerUtils.h:79
DynamicMemRefType(const StridedMemRefType< T, 0 > &memRef)
Definition: CRunnerUtils.h:333
DynamicMemRefType(const StridedMemRefType< T, N > &memRef)
Definition: CRunnerUtils.h:337
void dropFront(int64_t arr[N], int64_t *res)
Definition: CRunnerUtils.h:118
Iterate over all elements in a strided memref.
Definition: CRunnerUtils.h:127
const std::array< int64_t, 0 > & getIndices()
Definition: CRunnerUtils.h:286
StridedMemrefIterator< T, N > begin()
Definition: CRunnerUtils.h:152
constexpr bool isPowerOf2(int n)
Definition: CRunnerUtils.h:47
std::forward_iterator_tag iterator_category
Definition: CRunnerUtils.h:213
DynamicMemRefIterator< T > end()
Definition: CRunnerUtils.h:368
T & operator[](int64_t idx)
Definition: CRunnerUtils.h:187
StridedMemRef descriptor type with static rank.
Definition: CRunnerUtils.h:131
StridedMemrefIterator< T, 1 > begin()
Definition: CRunnerUtils.h:184
const T & operator[](unsigned i) const
Definition: CRunnerUtils.h:63
bool operator==(const StridedMemrefIterator &other) const
Definition: CRunnerUtils.h:243
bool operator==(const StridedMemrefIterator &other) const
Definition: CRunnerUtils.h:293
Vector< T, Dims... > & operator[](unsigned i)
Definition: CRunnerUtils.h:91
bool operator!=(const DynamicMemRefIterator &other) const
Definition: CRunnerUtils.h:438
const int64_t * strides
Definition: CRunnerUtils.h:331
StridedMemrefIterator< T, Rank > & operator++()
Definition: CRunnerUtils.h:222
std::forward_iterator_tag iterator_category
Definition: CRunnerUtils.h:267
StridedMemrefIterator(StridedMemRefType< T, 0 > &descriptor, int64_t offset=0)
Definition: CRunnerUtils.h:273
MLIR_CRUNNERUTILS_EXPORT double rtclock()
Iterate over all elements in a dynamic memref.
Definition: CRunnerUtils.h:320