15#include "llvm/Support/MathExtras.h"
26 assert(
value !=
nullptr &&
"value cannot be null");
27 indexType = cast<LLVM::LLVMStructType>(
value.getType())
33 Type descriptorType) {
35 Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType);
45 MemRefType type,
Value memory) {
46 return fromStaticShape(builder, loc, typeConverter, type, memory, memory);
51 MemRefType type,
Value memory,
Value alignedMemory) {
52 assert(type.hasStaticShape() &&
"unexpected dynamic shape");
55 auto [strides,
offset] = type.getStridesAndOffset();
56 assert(ShapedType::isStatic(
offset) &&
"expected static offset");
57 assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
58 "expected static strides");
60 auto convertedType = typeConverter.convertType(type);
61 assert(convertedType &&
"unexpected failure in memref type conversion");
64 descr.setAllocatedPtr(builder, loc, memory);
65 descr.setAlignedPtr(builder, loc, alignedMemory);
66 descr.setConstantOffset(builder, loc,
offset);
69 for (
unsigned i = 0, e = type.getRank(); i != e; ++i) {
70 descr.setConstantSize(builder, loc, i, type.getDimSize(i));
71 descr.setConstantStride(builder, loc, i, strides[i]);
102 return LLVM::ConstantOp::create(builder, loc, resultType,
108 return LLVM::ExtractValueOp::create(builder, loc,
value,
128 return LLVM::ExtractValueOp::create(
135 auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
137 auto ptrTy = LLVM::LLVMPointerType::get(builder.
getContext());
141 auto sizes = LLVM::ExtractValueOp::create(
144 auto sizesPtr = LLVM::AllocaOp::create(builder, loc, ptrTy, arrayTy, one,
146 LLVM::StoreOp::create(builder, loc, sizes, sizesPtr);
149 auto resultPtr = LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, sizesPtr,
151 return LLVM::LoadOp::create(builder, loc, indexType, resultPtr);
157 value = LLVM::InsertValueOp::create(
163 unsigned pos, uint64_t
size) {
170 return LLVM::ExtractValueOp::create(
178 value = LLVM::InsertValueOp::create(
184 unsigned pos, uint64_t
stride) {
190 return cast<LLVM::LLVMPointerType>(
191 cast<LLVM::LLVMStructType>(
value.getType())
200 auto [strides, offsetCst] = type.getStridesAndOffset();
210 ShapedType::isDynamic(offsetCst)
213 Type elementType = converter.convertType(type.getElementType());
214 ptr = LLVM::GEPOp::create(builder, loc,
ptr.getType(), elementType,
ptr,
230 Type llvmType = converter.convertType(type);
238 for (
unsigned i = 0; i < rank; ++i) {
256 results.push_back(d.
alignedPtr(builder, loc));
257 results.push_back(d.
offset(builder, loc));
258 for (
int64_t i = 0; i < rank; ++i)
259 results.push_back(d.
size(builder, loc, i));
260 for (
int64_t i = 0; i < rank; ++i)
261 results.push_back(d.
stride(builder, loc, i));
268 return 3 + 2 * type.getRank();
309 Type descriptorType) {
310 Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType);
337 Type llvmType = converter.convertType(type);
351 results.reserve(results.size() + 2);
352 results.push_back(d.
rank(builder, loc));
366 builder, loc, indexType,
377 builder, loc, indexType,
379 Value doublePointerSize =
380 LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
384 Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two,
rank);
385 Value doubleRankIncremented =
386 LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
387 Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
388 doubleRankIncremented, indexSize);
391 Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
392 doublePointerSize, rankIndexSize);
393 return allocationSize;
398 LLVM::LLVMPointerType elemPtrType) {
399 return LLVM::LoadOp::create(builder, loc, elemPtrType,
memRefDescPtr);
408static std::pair<Value, Type>
410 LLVM::LLVMPointerType elemPtrType) {
411 auto elemPtrPtrType = LLVM::LLVMPointerType::get(builder.
getContext());
412 return {memRefDescPtr, elemPtrPtrType};
418 auto [elementPtrPtr, elemPtrPtrType] =
422 LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
424 return LLVM::LoadOp::create(builder, loc, elemPtrType, alignedGep);
430 auto [elementPtrPtr, elemPtrPtrType] =
434 LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
436 LLVM::StoreOp::create(builder, loc,
alignedPtr, alignedGep);
442 auto [elementPtrPtr, elemPtrPtrType] =
445 return LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType,
452 LLVM::LLVMPointerType elemPtrType) {
455 return LLVM::LoadOp::create(builder, loc, typeConverter.
getIndexType(),
462 LLVM::LLVMPointerType elemPtrType,
466 LLVM::StoreOp::create(builder, loc,
offset, offsetPtr);
473 Type structTy = LLVM::LLVMStructType::getLiteral(
474 indexTy.
getContext(), {elemPtrType, elemPtrType, indexTy, indexTy});
475 auto resultType = LLVM::LLVMPointerType::get(builder.
getContext());
476 return LLVM::GEPOp::create(builder, loc, resultType, structTy,
memRefDescPtr,
485 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
489 return LLVM::LoadOp::create(builder, loc, indexTy, sizeStoreGep);
497 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
501 LLVM::StoreOp::create(builder, loc,
size, sizeStoreGep);
508 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
510 return LLVM::GEPOp::create(builder, loc, ptrType, indexTy,
sizeBasePtr,
rank);
518 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
520 Value strideStoreGep =
522 return LLVM::LoadOp::create(builder, loc, indexTy, strideStoreGep);
530 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
532 Value strideStoreGep =
534 LLVM::StoreOp::create(builder, loc,
stride, strideStoreGep);
static std::pair< Value, Type > castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kPtrInUnrankedMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static constexpr unsigned kRankInUnrankedMemRefDescriptor
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
Conversion from types to the LLVM IR dialect.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Value size(unsigned pos)
Returns the pos-th size Value.
Value alignedPtr()
Returns the aligned pointer Value.
Value stride(unsigned pos)
Returns the pos-th stride Value.
Value allocatedPtr()
Returns the allocated pointer Value.
MemRefDescriptorView(ValueRange range)
Constructs the view from a range of values.
Value offset()
Returns the offset Value.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
void setOffset(OpBuilder &builder, Location loc, Value offset)
Builds IR inserting the offset into the descriptor.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size)
MemRefDescriptor(Value descriptor)
Construct a helper for the given descriptor value.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size)
Builds IR inserting the pos-th size into the descriptor.
void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr)
Builds IR inserting the allocated pointer into the descriptor.
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride)
Builds IR inserting the pos-th stride into the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride)
void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr)
Builds IR inserting the aligned pointer into the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset)
Builds IR inserting the offset into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
This class helps build Operations.
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr)
Builds IR to set a value in the struct at position pos.
StructBuilder(Value v)
Construct a helper for the given value.
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const
Builds IR to extract a value from the struct at position pos.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
void setRank(OpBuilder &builder, Location loc, Value value)
Builds IR setting the rank in the descriptor.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static Value stride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR extracting the stride[index] from the descriptor.
static Value size(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index)
Builds IR extracting the size[index] from the descriptor.
UnrankedMemRefDescriptor(Value descriptor)
Construct a helper for the given descriptor value.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value)
Builds IR setting ranked memref descriptor ptr.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.