19#include "llvm/ADT/ArrayRef.h"
20#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/raw_ostream.h"
29#include <initializer_list>
33#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
34#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
47 assert(
shape.size() == N &&
"expect shape specification to match rank");
48 std::array<int64_t, N> res;
50 for (
int64_t idx = N - 1; idx >= 0; --idx) {
51 assert(
shape[idx] >= 0 &&
52 "size must be non-negative for all shape dimensions");
54 running *=
shape[idx];
64template <
int N,
typename T>
68 assert(
shape.size() == N);
69 assert(shapeAlloc.size() == N);
72 descriptor.
data = alignedPtr;
76 std::copy(strides.begin(), strides.end(), descriptor.
strides);
85template <
int N,
typename T>
89 assert(shape.size() == N);
90 assert(shapeAlloc.size() == N);
93 descriptor.
data = alignedPtr;
105 std::optional<uint64_t> alignment = std::optional<uint64_t>()) {
106 assert(
sizeof(T) <= UINT_MAX &&
"Elemental type overflows");
107 auto size = nElements *
sizeof(T);
108 auto desiredAlignment = alignment.value_or(
nextPowerOf2(
sizeof(T)));
109 assert((desiredAlignment & (desiredAlignment - 1)) == 0);
110 assert(desiredAlignment >=
sizeof(T));
111 T *data =
reinterpret_cast<T *
>(allocFun(size + desiredAlignment));
112 uintptr_t addr =
reinterpret_cast<uintptr_t
>(data);
113 uintptr_t
rem = addr % desiredAlignment;
114 T *alignedData = (
rem == 0)
116 :
reinterpret_cast<T *
>(addr + (desiredAlignment -
rem));
117 assert(
reinterpret_cast<uintptr_t
>(alignedData) % desiredAlignment == 0);
118 return std::make_pair(data, alignedData);
135template <
typename T,
unsigned Rank>
150 std::optional<uint64_t> alignment = std::optional<uint64_t>(),
156 : freeFunc(freeFun) {
157 if (shapeAlloc.empty())
159 assert(shape.size() == Rank);
160 assert(shapeAlloc.size() == Rank);
161 for (
unsigned i = 0; i < Rank; ++i)
162 assert(shape[i] <= shapeAlloc[i] &&
163 "shapeAlloc must be greater than or equal to shape");
164 int64_t nElements = 1;
165 for (int64_t s : shapeAlloc)
167 auto [allocatedPtr, alignedData] =
170 allocatedPtr, alignedData, shape, shapeAlloc);
172 for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
173 end = descriptor.end();
175 init(*it, it.getIndices());
177 memset(alignedData, 0, nElements *
sizeof(T));
182 : freeFunc(freeFunc), descriptor(descriptor) {}
185 freeFunc(descriptor);
190 freeFunc = other.freeFunc;
191 descriptor = other.descriptor;
192 other.freeFunc =
nullptr;
193 memset(&other.descriptor, 0,
sizeof(other.descriptor));
OwningMemRef(OwningMemRef &&other)
OwningMemRef(ArrayRef< int64_t > shape, ArrayRef< int64_t > shapeAlloc={}, ElementWiseVisitor< T > init={}, std::optional< uint64_t > alignment=std::optional< uint64_t >(), AllocFunType allocFun=&::malloc, std::function< void(StridedMemRefType< T, Rank >)> freeFun=[](StridedMemRefType< T, Rank > descriptor) { ::free(descriptor.basePtr);})
Allocate a new dense StridedMemRef with a given shape.
OwningMemRef & operator=(const OwningMemRef &)=delete
OwningMemRef(const OwningMemRef &)=delete
StridedMemRefType< T, Rank > DescriptorType
OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc)
Take ownership of an existing descriptor with a custom deleter.
DescriptorType * operator->()
DescriptorType & operator*()
T & operator[](std::initializer_list< int64_t > indices)
std::function< void(DescriptorType)> FreeFunType
OwningMemRef & operator=(const OwningMemRef &&other)
constexpr unsigned nextPowerOf2(int n)
std::pair< T *, T * > allocAligned(size_t nElements, AllocFunType allocFun=&::malloc, std::optional< uint64_t > alignment=std::optional< uint64_t >())
Align nElements of type T with an optional alignment.
std::array< int64_t, N > makeStrides(ArrayRef< int64_t > shape)
Given a shape with sizes greater than 0 along all dimensions, returns the distance, in number of elements, between a slice in a dimension and the next slice in the same dimension (e.g. shape [3, 4, 5] yields strides [20, 5, 1]).
std::enable_if<(N >=1), StridedMemRefType< T, N > >::type makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef< int64_t > shape, ArrayRef< int64_t > shapeAlloc)
Build a StridedMemRefDescriptor<T, N> that matches the MLIR ABI.
Include the generated interface declarations.
llvm::function_ref< void *(size_t)> AllocFunType
llvm::function_ref< void(T &ptr, ArrayRef< int64_t >)> ElementWiseVisitor
Convenient callback to "visit" a memref element by element.
StridedMemRef descriptor type with static rank.