15 #include "llvm/Support/MathExtras.h"
26 assert(
value !=
nullptr &&
"value cannot be null");
33 Type descriptorType) {
35 Value descriptor = builder.
create<LLVM::UndefOp>(loc, descriptorType);
45 MemRefType type,
Value memory) {
46 return fromStaticShape(builder, loc, typeConverter, type, memory, memory);
51 MemRefType type,
Value memory,
Value alignedMemory) {
52 assert(type.hasStaticShape() &&
"unexpected dynamic shape");
56 assert(!ShapedType::isDynamic(
offset) &&
"expected static offset");
57 assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
58 "expected static strides");
60 auto convertedType = typeConverter.
convertType(type);
61 assert(convertedType &&
"unexpected failure in memref type conversion");
64 descr.setAllocatedPtr(builder, loc, memory);
65 descr.setAlignedPtr(builder, loc, alignedMemory);
66 descr.setConstantOffset(builder, loc,
offset);
69 for (
unsigned i = 0, e = type.getRank(); i != e; ++i) {
70 descr.setConstantSize(builder, loc, i, type.getDimSize(i));
71 descr.setConstantStride(builder, loc, i, strides[i]);
101 Type resultType, int64_t value) {
102 return builder.
create<LLVM::ConstantOp>(loc, resultType,
108 return builder.
create<LLVM::ExtractValueOp>(loc,
value,
128 return builder.
create<LLVM::ExtractValueOp>(
140 auto sizes = builder.
create<LLVM::ExtractValueOp>(
142 auto sizesPtr = builder.
create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one,
144 builder.
create<LLVM::StoreOp>(loc, sizes, sizesPtr);
147 auto resultPtr = builder.
create<LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr,
149 return builder.
create<LLVM::LoadOp>(loc, indexType, resultPtr);
160 unsigned pos, uint64_t size) {
167 return builder.
create<LLVM::ExtractValueOp>(
180 unsigned pos, uint64_t stride) {
186 return cast<LLVM::LLVMPointerType>(
206 ShapedType::isDynamic(offsetCst)
210 ptr = builder.
create<LLVM::GEPOp>(loc, ptr.
getType(), elementType, ptr,
233 int64_t rank = type.getRank();
234 for (
unsigned i = 0; i < rank; ++i) {
247 int64_t rank = type.getRank();
252 results.push_back(d.
alignedPtr(builder, loc));
253 results.push_back(d.
offset(builder, loc));
254 for (int64_t i = 0; i < rank; ++i)
255 results.push_back(d.
size(builder, loc, i));
256 for (int64_t i = 0; i < rank; ++i)
257 results.push_back(d.
stride(builder, loc, i));
264 return 3 + 2 * type.getRank();
305 Type descriptorType) {
306 Value descriptor = builder.
create<LLVM::UndefOp>(loc, descriptorType);
347 results.reserve(results.size() + 2);
348 results.push_back(d.
rank(builder, loc));
358 assert(values.size() == addressSpaces.size() &&
359 "must provide address space for each descriptor");
367 builder, loc, indexType,
370 sizes.reserve(sizes.size() + values.size());
371 for (
auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
380 builder, loc, indexType,
382 Value doublePointerSize =
383 builder.
create<LLVM::MulOp>(loc, indexType, two, pointerSize);
387 Value doubleRank = builder.
create<LLVM::MulOp>(loc, indexType, two,
rank);
388 Value doubleRankIncremented =
389 builder.
create<LLVM::AddOp>(loc, indexType, doubleRank, one);
391 loc, indexType, doubleRankIncremented, indexSize);
394 Value allocationSize = builder.
create<LLVM::AddOp>(
395 loc, indexType, doublePointerSize, rankIndexSize);
396 sizes.push_back(allocationSize);
402 LLVM::LLVMPointerType elemPtrType) {
408 LLVM::LLVMPointerType elemPtrType,
Value allocatedPtr) {
412 static std::pair<Value, Type>
414 LLVM::LLVMPointerType elemPtrType) {
416 return {memRefDescPtr, elemPtrPtrType};
421 Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
422 auto [elementPtrPtr, elemPtrPtrType] =
426 builder.
create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
428 return builder.
create<LLVM::LoadOp>(loc, elemPtrType, alignedGep);
433 Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType,
Value alignedPtr) {
434 auto [elementPtrPtr, elemPtrPtrType] =
438 builder.
create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
445 Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
446 auto [elementPtrPtr, elemPtrPtrType] =
449 return builder.
create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
456 LLVM::LLVMPointerType elemPtrType) {
466 LLVM::LLVMPointerType elemPtrType,
475 Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
477 Type structTy = LLVM::LLVMStructType::getLiteral(
478 indexTy.
getContext(), {elemPtrType, elemPtrType, indexTy, indexTy});
493 return builder.
create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep);
505 builder.
create<LLVM::StoreOp>(loc,
size, sizeStoreGep);
524 Value strideStoreGep =
526 return builder.
create<LLVM::LoadOp>(loc, indexTy, strideStoreGep);
536 Value strideStoreGep =
538 builder.
create<LLVM::StoreOp>(loc,
stride, strideStoreGep);
static std::pair< Value, Type > castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kPtrInUnrankedMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static constexpr unsigned kRankInUnrankedMemRefDescriptor
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Value size(unsigned pos)
Returns the pos-th size Value.
Value alignedPtr()
Returns the aligned pointer Value.
Value stride(unsigned pos)
Returns the pos-th stride Value.
Value allocatedPtr()
Returns the allocated pointer Value.
MemRefDescriptorView(ValueRange range)
Constructs the view from a range of values.
Value offset()
Returns the offset Value.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size)
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size)
Builds IR inserting the pos-th size into the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void setOffset(OpBuilder &builder, Location loc, Value offset)
Builds IR inserting the offset into the descriptor.
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride)
Builds IR inserting the pos-th stride into the descriptor.
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr)
Builds IR inserting the aligned pointer into the descriptor.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset)
Builds IR inserting the offset into the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride)
MemRefDescriptor(Value descriptor)
Construct a helper for the given descriptor value.
void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr)
Builds IR inserting the allocated pointer into the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Helper class to produce LLVM dialect operations extracting or inserting values to a struct.
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr)
Builds IR to set a value in the struct at position pos.
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const
Builds IR to extract a value from the struct at position pos.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
void setRank(OpBuilder &builder, Location loc, Value value)
Builds IR setting the rank in the descriptor.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static Value stride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR extracting the stride[index] from the descriptor.
static Value size(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index)
Builds IR extracting the size[index] from the descriptor.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
UnrankedMemRefDescriptor(Value descriptor)
Construct a helper for the given descriptor value.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value)
Builds IR setting ranked memref descriptor ptr.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...