15#include "llvm/Support/CheckedArithmetic.h"
16#include "llvm/Support/MathExtras.h"
27 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
31 ConversionPattern::getTypeConverter());
62 return LLVM::ConstantOp::create(builder, loc, resultType,
67 ConversionPatternRewriter &rewriter,
Location loc, MemRefType type,
69 LLVM::GEPNoWrapFlags noWrapFlags)
const {
71 memRefDesc,
indices, noWrapFlags);
77 MemRefType type)
const {
78 if (!type.getLayout().isIdentity())
80 return static_cast<bool>(typeConverter->convertType(type));
85 if (failed(addressSpace))
87 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
95 "layout maps must have been normalized away");
96 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
97 static_cast<ssize_t
>(dynamicSizes.size()) &&
98 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
100 sizes.reserve(memRefType.getRank());
101 unsigned dynamicIndex = 0;
103 for (
int64_t size : memRefType.getShape()) {
105 size == ShapedType::kDynamic
106 ? dynamicSizes[dynamicIndex++]
112 bool overflowed =
false;
114 strides.resize(memRefType.getRank());
115 for (
auto i = memRefType.getRank(); i-- > 0;) {
116 strides[i] = overflowed ? LLVM::PoisonOp::create(rewriter, loc, indexType)
119 int64_t staticSize = memRefType.getShape()[i];
120 bool useSizeAsStride = stride == 1;
121 if (staticSize == ShapedType::kDynamic)
122 stride = ShapedType::kDynamic;
123 if (stride != ShapedType::kDynamic) {
124 std::optional<int64_t> res = llvm::checkedMul(stride, staticSize);
129 stride = res.value();
133 runningStride = LLVM::PoisonOp::create(rewriter, loc, indexType);
134 else if (useSizeAsStride)
135 runningStride = sizes[i];
136 else if (stride == ShapedType::kDynamic)
138 LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]);
144 Type elementType = typeConverter->convertType(memRefType.getElementType());
145 auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
146 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
147 Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType,
148 elementType, nullPtr, runningStride);
149 size = LLVM::PtrToIntOp::create(rewriter, loc,
getIndexType(), gepPtr);
151 size = runningStride;
156 Location loc,
Type type, ConversionPatternRewriter &rewriter)
const {
162 Type llvmType = typeConverter->convertType(type);
163 auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
164 auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType);
165 auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType,
167 return LLVM::PtrToIntOp::create(rewriter, loc,
getIndexType(), gep);
172 ConversionPatternRewriter &rewriter)
const {
173 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
174 static_cast<ssize_t
>(dynamicSizes.size()) &&
175 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
178 Value numElements = memRefType.getRank() == 0
181 unsigned dynamicIndex = 0;
184 for (
int64_t staticSize : memRefType.getShape()) {
187 staticSize == ShapedType::kDynamic
188 ? dynamicSizes[dynamicIndex++]
190 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
193 staticSize == ShapedType::kDynamic
194 ? dynamicSizes[dynamicIndex++]
205 ConversionPatternRewriter &rewriter)
const {
206 auto structType = typeConverter->convertType(memRefType);
210 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
213 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
217 memRefDescriptor.setOffset(
221 for (
const auto &en : llvm::enumerate(sizes))
222 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
225 for (
const auto &en : llvm::enumerate(strides))
226 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
228 return memRefDescriptor;
233 Value operand,
bool toDynamic)
const {
235 FailureOr<unsigned> addressSpace =
237 if (failed(addressSpace))
244 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
245 FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
248 if (failed(mallocFunc))
253 if (failed(freeFunc))
262 Value memory = toDynamic
263 ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
266 : LLVM::AllocaOp::create(builder, loc,
getPtrType(),
271 LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize,
false);
273 LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
285 updatedDesc.setRank(builder, loc, rank);
286 updatedDesc.setMemRefDescPtr(builder, loc, memory);
293 assert(origTypes.size() == operands.size() &&
294 "expected as may original types as operands");
295 for (
unsigned i = 0, e = operands.size(); i < e; ++i) {
296 if (
auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
298 operands[i], toDynamic);
301 operands[i] = updatedDesc;
317 ConversionPatternRewriter &rewriter) {
321 if (numResults != 0) {
322 resultTypes.push_back(
324 if (!resultTypes.back())
330 resultTypes, targetAttrs);
332 Operation *newOp = rewriter.create(state);
336 return rewriter.eraseOp(op),
success();
343 results.reserve(numResults);
344 for (
unsigned i = 0; i < numResults; ++i) {
345 results.push_back(LLVM::ExtractValueOp::create(rewriter, op->
getLoc(),
348 rewriter.replaceOp(op, results);
357 if (!llvm::all_of(operands, [](
Value value) {
367 auto callIntrOp = LLVM::CallIntrinsicOp::create(
368 rewriter, loc, resType, rewriter.
getStringAttr(intrinsic), operands);
372 if (numResults <= 1) {
381 results.reserve(numResults);
382 Value intrRes = callIntrOp.getResults();
383 for (
unsigned i = 0; i < numResults; ++i)
384 results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i));
394 auto vec = cast<VectorType>(type);
395 assert(!vec.isScalable() &&
"scalable vectors are not supported");
396 return vec.getNumElements() *
getBitWidth(vec.getElementType());
404 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
406 if (
auto structType = dyn_cast<LLVM::LLVMStructType>(type))
407 return llvm::all_of(structType.getBody(), [&](
Type fieldType) {
408 return isFixedSizeAggregate(fieldType, dstType);
410 if (
auto vecTy = dyn_cast<VectorType>(type))
411 return !vecTy.isScalable();
418 return LLVM::ConstantOp::create(builder, loc, i32, value);
427 if (srcType == dstType) {
432 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(srcType)) {
433 for (
auto i : llvm::seq(arrayType.getNumElements())) {
434 Value elem = LLVM::ExtractValueOp::create(builder, loc, src, i);
440 if (
auto structType = dyn_cast<LLVM::LLVMStructType>(srcType)) {
441 for (
auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
442 Value field = LLVM::ExtractValueOp::create(builder, loc, src,
450 if (!srcType.
isIntOrFloat() && !isa<VectorType>(srcType)) {
457 if (srcBitWidth == dstBitWidth) {
458 Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
463 if (dstBitWidth > srcBitWidth) {
465 if (srcType != smallerInt)
466 src = LLVM::BitcastOp::create(builder, loc, smallerInt, src);
469 Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
473 int64_t numElements = llvm::divideCeil(srcBitWidth, dstBitWidth);
474 int64_t roundedBitWidth = numElements * dstBitWidth;
477 if (roundedBitWidth != srcBitWidth) {
479 if (srcType != srcInt)
480 src = LLVM::BitcastOp::create(builder, loc, srcInt, src);
482 src = LLVM::ZExtOp::create(builder, loc, roundedInt, src);
485 auto vecType = VectorType::get(numElements, dstType);
486 src = LLVM::BitcastOp::create(builder, loc, vecType, src);
488 for (
auto i : llvm::seq(numElements)) {
490 Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
498 bool permitVariablySizedScalars) {
501 if (!permitVariablySizedScalars &&
512 size_t &offset,
Type dstType) {
513 if (
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(dstType)) {
514 Value result = LLVM::PoisonOp::create(builder, loc, arrayType);
515 Type elemType = arrayType.getElementType();
516 for (
auto i : llvm::seq(arrayType.getNumElements())) {
518 result = LLVM::InsertValueOp::create(builder, loc,
result, elem, i);
523 if (
auto structType = dyn_cast<LLVM::LLVMStructType>(dstType)) {
524 Value result = LLVM::PoisonOp::create(builder, loc, structType);
525 for (
auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
527 result = LLVM::InsertValueOp::create(builder, loc,
result, field,
534 if (!dstType.
isIntOrFloat() && !isa<VectorType>(dstType))
535 return src[offset++];
539 Value front = src[offset];
540 if (front.
getType() == dstType) {
548 if (srcBitWidth >= dstBitWidth) {
551 if (dstBitWidth < srcBitWidth) {
553 if (res.
getType() != largerInt)
554 res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
557 res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
560 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
567 int64_t numElements = llvm::divideCeil(dstBitWidth, elemBitWidth);
568 int64_t roundedBitWidth = numElements * elemBitWidth;
570 auto vecType = VectorType::get(numElements, front.
getType());
571 Value res = LLVM::PoisonOp::create(builder, loc, vecType);
572 for (
auto i : llvm::seq(numElements)) {
574 res = LLVM::InsertElementOp::create(builder, loc, vecType, res,
579 if (roundedBitWidth != dstBitWidth) {
581 res = LLVM::BitcastOp::create(builder, loc, roundedInt, res);
583 res = LLVM::TruncOp::create(builder, loc, dstInt, res);
584 if (dstType != dstInt)
585 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
588 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
596 assert(!src.empty() &&
"src range must not be empty");
599 assert(offset == src.size() &&
"not all decomposed values were consumed");
605 MemRefType type,
Value memRefDesc,
607 LLVM::GEPNoWrapFlags noWrapFlags) {
608 auto [strides, offset] = type.getStridesAndOffset();
615 Value base = memRefDescriptor.
bufferPtr(builder, loc, converter, type);
617 LLVM::IntegerOverflowFlags intOverflowFlags =
618 LLVM::IntegerOverflowFlags::none;
619 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
620 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
622 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
623 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
628 for (
int i = 0, e =
indices.size(); i < e; ++i) {
630 if (strides[i] != 1) {
632 ShapedType::isDynamic(strides[i])
633 ? memRefDescriptor.
stride(builder, loc, i)
634 : LLVM::ConstantOp::create(builder, loc, indexType,
636 increment = LLVM::MulOp::create(builder, loc, increment, stride,
646 ? LLVM::GEPOp::create(builder, loc, elementPtrType,
647 converter.convertType(type.getElementType()),
648 base,
index, noWrapFlags)
655 if (
auto floatType = dyn_cast<FloatType>(type))
657 if (
auto vecType = dyn_cast<VectorType>(type))
658 return dyn_cast<FloatType>(vecType.getElementType());
667 Type convertedType = typeConverter.convertType(floatType);
670 return !isa<FloatType>(convertedType);
679 return isUnsupportedFloatingPointType(typeConverter, r.getType());
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static unsigned getBitWidth(Type type)
static FloatType getFloatingPointType(Type type)
Return the given type if it's a floating point type; if it's a vector type, return its element type when that element type is a floating point type.
static bool isFixedSizeAggregate(Type type, Type dstType)
Returns true if every leaf in type (recursing through LLVM arrays and structs) is either equal to dst...
static Value composeValueImpl(OpBuilder &builder, Location loc, ValueRange src, size_t &offset, Type dstType)
Recursive implementation of composeValue.
static void decomposeValueImpl(OpBuilder &builder, Location loc, Value src, Type dstType, SmallVectorImpl< Value > &result)
Recursive implementation of decomposeValue.
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Type getPtrType(unsigned addressSpace=0) const
Get the MLIR type wrapping the LLVM ptr type.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
const LLVMTypeConverter * getTypeConverter() const
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const
Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to sta...
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref type is convertible to LLVM and has an identity layout map.
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Conversion from types to the LLVM IR dialect.
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic...
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
LogicalResult decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType, SmallVectorImpl< Value > &result, bool permitVariablySizedScalars=false)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.