58 return builder.
create<LLVM::ConstantOp>(loc, resultType,
78 for (
int i = 0, e = indices.size(); i < e; ++i) {
79 Value increment = indices[i];
80 if (strides[i] != 1) {
82 ShapedType::isDynamic(strides[i])
83 ? memRefDescriptor.stride(rewriter, loc, i)
85 increment = rewriter.
create<LLVM::MulOp>(loc, increment, stride);
88 index ? rewriter.
create<LLVM::AddOp>(loc, index, increment) : increment;
91 Type elementPtrType = memRefDescriptor.getElementPtrType();
92 return index ? rewriter.
create<LLVM::GEPOp>(
102 MemRefType type)
const {
105 return type.getLayout().isIdentity();
109 auto elementType = type.getElementType();
122 "layout maps must have been normalized away");
123 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
124 static_cast<ssize_t
>(dynamicSizes.size()) &&
125 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
127 sizes.reserve(memRefType.getRank());
128 unsigned dynamicIndex = 0;
130 for (int64_t size : memRefType.getShape()) {
132 size == ShapedType::kDynamic
133 ? dynamicSizes[dynamicIndex++]
140 strides.resize(memRefType.getRank());
141 for (
auto i = memRefType.getRank(); i-- > 0;) {
142 strides[i] = runningStride;
144 int64_t staticSize = memRefType.getShape()[i];
147 bool useSizeAsStride = stride == 1;
148 if (staticSize == ShapedType::kDynamic)
149 stride = ShapedType::kDynamic;
150 if (stride != ShapedType::kDynamic)
151 stride *= staticSize;
154 runningStride = sizes[i];
155 else if (stride == ShapedType::kDynamic)
157 rewriter.
create<LLVM::MulOp>(loc, runningStride, sizes[i]);
165 Value nullPtr = rewriter.
create<LLVM::ZeroOp>(loc, elementPtrType);
167 loc, elementPtrType, elementType, nullPtr, runningStride);
170 size = runningStride;
183 auto nullPtr = rewriter.
create<LLVM::ZeroOp>(loc, convertedPtrType);
184 auto gep = rewriter.
create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
192 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
193 static_cast<ssize_t
>(dynamicSizes.size()) &&
194 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
197 Value numElements = memRefType.getRank() == 0
200 unsigned dynamicIndex = 0;
203 for (int64_t staticSize : memRefType.getShape()) {
206 staticSize == ShapedType::kDynamic
207 ? dynamicSizes[dynamicIndex++]
209 numElements = rewriter.
create<LLVM::MulOp>(loc, numElements, size);
212 staticSize == ShapedType::kDynamic
213 ? dynamicSizes[dynamicIndex++]
229 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
232 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
236 memRefDescriptor.setOffset(
241 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
245 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
247 return memRefDescriptor;
253 assert(origTypes.size() == operands.size() &&
254 "expected as may original types as operands");
259 for (
unsigned i = 0, e = operands.size(); i < e; ++i) {
260 if (
auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
261 unrankedMemrefs.emplace_back(operands[i]);
266 unrankedAddressSpaces.emplace_back(*addressSpace);
270 if (unrankedMemrefs.empty())
276 unrankedMemrefs, unrankedAddressSpaces,
284 LLVM::LLVMFuncOp freeFunc, mallocFunc;
292 unsigned unrankedMemrefPos = 0;
293 for (
unsigned i = 0, e = operands.size(); i < e; ++i) {
294 Type type = origTypes[i];
295 if (!isa<UnrankedMemRefType>(type))
297 Value allocationSize = sizes[unrankedMemrefPos++];
303 ? builder.
create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
310 builder.
create<LLVM::MemcpyOp>(loc, memory, source, allocationSize,
false);
312 builder.
create<LLVM::CallOp>(loc, freeFunc, source);
324 updatedDesc.setRank(builder, loc, rank);
325 updatedDesc.setMemRefDescPtr(builder, loc, memory);
327 operands[i] = updatedDesc;
348 if (numResults != 0) {
349 resultTypes.push_back(
351 if (!resultTypes.back())
358 resultTypes, targetAttrs);
369 results.reserve(numResults);
370 for (
unsigned i = 0; i < numResults; ++i) {
371 results.push_back(rewriter.
create<LLVM::ExtractValueOp>(
IntegerAttr getIndexAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
const LLVMTypeConverter * getTypeConverter() const
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM.
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
This class provides support for representing a failure result, or a valid value of type T.
Conversion from types to the LLVM IR dialect.
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0) const
Creates an LLVM pointer type with the given element type and address space.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp, bool opaquePointers)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType, bool opaquePointers)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.