15 #ifndef MLIR_EXECUTIONENGINE_CRUNNERUTILS_H 16 #define MLIR_EXECUTIONENGINE_CRUNNERUTILS_H 19 #ifndef MLIR_CRUNNERUTILS_EXPORT 20 #ifdef mlir_c_runner_utils_EXPORTS 22 #define MLIR_CRUNNERUTILS_EXPORT __declspec(dllexport) 23 #define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS 26 #define MLIR_CRUNNERUTILS_EXPORT __declspec(dllimport) 27 #endif // mlir_c_runner_utils_EXPORTS 28 #endif // MLIR_CRUNNERUTILS_EXPORT 31 #define MLIR_CRUNNERUTILS_EXPORT __attribute__((visibility("default"))) 32 #define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS 38 #include <initializer_list> 46 constexpr
bool isPowerOf2(int n) {
  // Returns true iff `n` is a strictly positive power of two.
  //
  // `n & (n - 1)` clears the lowest set bit of `n`, so the result is zero
  // exactly when `n` has at most one bit set. The `n > 0` guard is required
  // because the bit trick alone also yields zero for n == 0 (which is not a
  // power of two), and because `n - 1` would overflow for the most negative
  // int; short-circuiting on the guard keeps that subtraction from ever being
  // evaluated for non-positive inputs.
  //
  // Kept as a single return statement so the function stays usable as a
  // C++11-style constexpr function.
  return n > 0 && (n & (n - 1)) == 0;
}
52 template <
typename T,
int Dim,
bool IsPowerOf2>
55 template <
typename T,
int Dim>
62 inline const T &
operator[](
unsigned i)
const {
return vector[i]; }
70 template <
typename T,
int Dim>
73 static_assert(
nextPowerOf2(
sizeof(T[
Dim])) >
sizeof(T[Dim]),
"size error");
74 static_assert(
nextPowerOf2(
sizeof(T[Dim])) < 2 *
sizeof(T[Dim]),
78 inline const T &
operator[](
unsigned i)
const {
return vector[i]; }
88 template <
typename T,
int Dim,
int... Dims>
101 template <
typename T,
int Dim>
104 mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
107 template <
int D1,
typename T>
109 template <
int D1,
int D2,
typename T>
111 template <
int D1,
int D2,
int D3,
typename T>
113 template <
int D1,
int D2,
int D3,
int D4,
typename T>
118 for (
unsigned i = 1; i < N; ++i)
119 *(res + i - 1) = arr[i];
125 template <
typename T,
int Rank>
129 template <
typename T,
int N>
137 template <
typename Range,
138 typename sfinae = decltype(std::declval<Range>().begin())>
140 assert(indices.size() == N &&
141 "indices should match rank in memref subscript");
142 int64_t curOffset = offset;
143 for (
int dim = N - 1; dim >= 0; --dim) {
144 int64_t currentIndex = *(indices.begin() + dim);
145 assert(currentIndex < sizes[dim] &&
"Index overflow");
146 curOffset += currentIndex * strides[dim];
148 return data[curOffset];
157 res.basePtr = basePtr;
159 res.offset = offset + idx * strides[0];
160 dropFront<N>(sizes, res.sizes);
161 dropFront<N>(strides, res.strides);
167 template <
typename T>
175 template <
typename Range,
176 typename sfinae = decltype(std::declval<Range>().begin())>
178 assert(indices.size() == 1 &&
179 "indices should match rank in memref subscript");
180 return (*
this)[*indices.begin()];
186 T &
operator[](int64_t idx) {
return *(data + offset + idx * strides[0]); }
190 template <
typename T>
196 template <
typename Range,
197 typename sfinae = decltype(std::declval<Range>().begin())>
199 assert((indices.size() == 0) &&
200 "Expect empty indices for 0-rank memref subscript");
209 template <
typename T,
int Rank>
214 : offset(offset), descriptor(descriptor) {}
217 while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
218 offset -= indices[dim] * descriptor.strides[dim];
227 offset += descriptor.strides[dim];
234 const std::array<int64_t, Rank> &
getIndices() {
return indices; }
237 return other.offset == offset && &other.descriptor == &descriptor;
241 return !(*
this == other);
249 std::array<int64_t, Rank> indices = {};
255 template <
typename T>
259 : elt(descriptor.data + offset) {}
274 static const std::array<int64_t, 0> indices = {};
279 return other.elt == elt;
283 return !(*
this == other);
295 template <
typename T>
305 template <
typename T>
309 : rank(0), basePtr(memRef.basePtr), data(memRef.data),
310 offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
313 : rank(N), basePtr(memRef.basePtr), data(memRef.data),
314 offset(memRef.offset), sizes(memRef.sizes), strides(memRef.strides) {}
316 : rank(memRef.rank) {
320 offset = desc->offset;
321 sizes = rank == 0 ? nullptr : desc->sizes;
322 strides = sizes + rank;
358 #endif // MLIR_EXECUTIONENGINE_CRUNNERUTILS_H MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u)
Include the generated interface declarations.
T & operator[](Range &&indices)
StridedMemrefIterator(StridedMemRefType< T, Rank > &descriptor, int64_t offset=0)
bool operator!=(const StridedMemrefIterator &other) const
StridedMemRef descriptor type specialized for rank 1.
MLIR_CRUNNERUTILS_EXPORT void printF64(double d)
StridedMemrefIterator< T, 1 > end()
const std::array< int64_t, Rank > & getIndices()
StridedMemrefIterator< T, 0 > end()
StridedMemRefType< T, N - 1 > operator[](int64_t idx)
MLIR_CRUNNERUTILS_EXPORT void printF32(float f)
T & operator[](Range indices)
Iterate over all elements in a 0-ranked strided memref.
MLIR_CRUNNERUTILS_EXPORT void printOpen()
bool operator!=(const StridedMemrefIterator &other) const
MLIR_CRUNNERUTILS_EXPORT void printClose()
MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i)
MLIR_CRUNNERUTILS_EXPORT void printComma()
MLIR_CRUNNERUTILS_EXPORT void memrefCopy(int64_t elemSize, ::UnrankedMemRefType< char > *src, ::UnrankedMemRefType< char > *dst)
MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops)
StridedMemrefIterator< T, N > end()
T & operator[](Range indices)
MLIR_CRUNNERUTILS_EXPORT void printNewline()
constexpr unsigned nextPowerOf2(int n)
StridedMemrefIterator< T, 0 > begin()
#define MLIR_CRUNNERUTILS_EXPORT
StridedMemRef descriptor type specialized for rank 0.
const Vector< T, Dims... > & operator[](unsigned i) const
DynamicMemRefType(const ::UnrankedMemRefType< T > &memRef)
StridedMemrefIterator< T, 0 > & operator++()
const T & operator[](unsigned i) const
T & operator[](unsigned i)
DynamicMemRefType(const StridedMemRefType< T, 0 > &memRef)
DynamicMemRefType(const StridedMemRefType< T, N > &memRef)
void dropFront(int64_t arr[N], int64_t *res)
Iterate over all elements in a strided memref.
Dim
Dimension level type for a tensor (undef means index does not appear).
const std::array< int64_t, 0 > & getIndices()
StridedMemrefIterator< T, N > begin()
constexpr bool isPowerOf2(int n)
T & operator[](int64_t idx)
StridedMemRef descriptor type with static rank.
StridedMemrefIterator< T, 1 > begin()
const T & operator[](unsigned i) const
bool operator==(const StridedMemrefIterator &other) const
bool operator==(const StridedMemrefIterator &other) const
Vector< T, Dims... > & operator[](unsigned i)
StridedMemrefIterator< T, Rank > & operator++()
StridedMemrefIterator(StridedMemRefType< T, 0 > &descriptor, int64_t offset=0)
MLIR_CRUNNERUTILS_EXPORT double rtclock()
T & operator[](unsigned i)