14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
31 return *recursiveStack->second;
39 return *recursiveStackInserted.first->second;
49 return values.size() == 1 &&
50 isa<LLVM::LLVMPointerType>(values.front().
getType());
70 assert(resultType &&
"expected non-null result type");
73 resultType, inputs[0]);
95 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, packed)
101 MemRefType resultType,
111 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, packed)
119 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
options(
options),
120 dataLayoutAnalysis(analysis) {
121 assert(
llvmDialect &&
"LLVM IR dialect is not registered");
124 addConversion([&](ComplexType type) {
return convertComplexType(type); });
125 addConversion([&](FloatType type) {
return convertFloatType(type); });
126 addConversion([&](FunctionType type) {
return convertFunctionType(type); });
127 addConversion([&](IndexType type) {
return convertIndexType(type); });
128 addConversion([&](IntegerType type) {
return convertIntegerType(type); });
129 addConversion([&](MemRefType type) {
return convertMemRefType(type); });
133 FailureOr<Type> llvmType = convertVectorType(type);
134 if (failed(llvmType))
148 -> std::optional<LogicalResult> {
151 results.push_back(type);
155 if (type.isIdentified()) {
156 auto convertedType = LLVM::LLVMStructType::getIdentified(
157 type.getContext(), (
"_Converted." + type.getName()).str());
160 if (llvm::count(recursiveStack, type)) {
161 results.push_back(convertedType);
164 recursiveStack.push_back(type);
165 auto popConversionCallStack = llvm::make_scope_exit(
166 [&recursiveStack]() { recursiveStack.pop_back(); });
169 convertedElemTypes.reserve(type.getBody().size());
170 if (failed(
convertTypes(type.getBody(), convertedElemTypes)))
175 if (!convertedType.isInitialized()) {
177 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
180 results.push_back(convertedType);
188 convertedType.isPacked() == type.isPacked()) {
189 results.push_back(convertedType);
197 convertedSubtypes.reserve(type.getBody().size());
198 if (failed(
convertTypes(type.getBody(), convertedSubtypes)))
201 results.push_back(LLVM::LLVMStructType::getLiteral(
202 type.getContext(), convertedSubtypes, type.isPacked()));
205 addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
206 if (
auto element =
convertType(type.getElementType()))
210 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
212 if (!convertedResType)
216 convertedArgTypes.reserve(type.getNumParams());
217 if (failed(
convertTypes(type.getParams(), convertedArgTypes)))
228 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, inputs)
233 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, inputs)
262 if (
auto memrefType = dyn_cast<MemRefType>(originalType))
264 if (
auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
272 [](
BaseMemRefType memref, IntegerAttr addrspace) {
return addrspace; });
285 return options.
dataLayout.getPointerSizeInBits(addressSpace);
288 Type LLVMTypeConverter::convertIndexType(IndexType type)
const {
292 Type LLVMTypeConverter::convertIntegerType(IntegerType type)
const {
296 Type LLVMTypeConverter::convertFloatType(FloatType type)
const {
302 if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
303 type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
304 type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
305 type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
306 type.isFloat8E8M0FNU())
318 Type LLVMTypeConverter::convertComplexType(ComplexType type)
const {
319 auto elementType =
convertType(type.getElementType());
320 return LLVM::LLVMStructType::getLiteral(&
getContext(),
321 {elementType, elementType});
326 Type LLVMTypeConverter::convertFunctionType(FunctionType type)
const {
336 assert(result.empty() &&
"Unexpected non-empty output");
337 result.resize(funcOp.getNumArguments(), std::nullopt);
338 bool foundByValByRefAttrs =
false;
339 for (
int argIdx : llvm::seq(funcOp.getNumArguments())) {
341 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
342 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
343 foundByValByRefAttrs =
true;
344 result[argIdx] = namedAttr;
350 if (!foundByValByRefAttrs)
362 Type LLVMTypeConverter::convertFunctionSignatureImpl(
363 FunctionType funcTy,
bool isVariadic,
bool useBarePtrCallConv,
364 LLVMTypeConverter::SignatureConversion &result,
365 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs)
const {
373 if (failed(funcArgConverter(*
this, type, converted)))
378 if (byValRefNonPtrAttrs !=
nullptr && !byValRefNonPtrAttrs->empty() &&
379 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
382 if (isa<LLVM::LLVMPointerType>(converted[0]))
383 (*byValRefNonPtrAttrs)[idx] = std::nullopt;
388 result.addInputs(idx, converted);
395 funcTy.getNumResults() == 0
405 FunctionType funcTy,
bool isVariadic,
bool useBarePtrCallConv,
407 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
413 FunctionOpInterface funcOp,
bool isVariadic,
bool useBarePtrCallConv,
415 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs)
const {
420 auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
421 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
422 result, &byValRefNonPtrAttrs);
427 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
431 Type resultType = type.getNumResults() == 0
438 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
442 inputs.push_back(ptrType);
446 for (
Type t : type.getInputs()) {
450 if (isa<MemRefType, UnrankedMemRefType>(t))
452 inputs.push_back(converted);
488 bool unpackAggregates)
const {
492 "conversion to strided form failed either due to non-strided layout "
493 "maps (which should have been normalized away) or other reasons");
502 if (failed(addressSpace)) {
504 "conversion of memref memory space ")
505 << type.getMemorySpace()
506 <<
" to integer address space "
507 "failed. Consider adding memory space conversions.";
515 auto rank = type.getRank();
519 if (unpackAggregates)
520 results.insert(results.end(), 2 * rank, indexTy);
537 Type LLVMTypeConverter::convertMemRefType(MemRefType type)
const {
544 return LLVM::LLVMStructType::getLiteral(&
getContext(), types);
567 Type LLVMTypeConverter::convertUnrankedMemRefType(
571 return LLVM::LLVMStructType::getLiteral(&
getContext(),
579 std::optional<Attribute> converted =
585 if (
auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
586 if (explicitSpace.getType().isIndex() ||
587 explicitSpace.getType().isSignlessInteger())
588 return explicitSpace.getInt();
595 if (isa<UnrankedMemRefType>(type))
601 auto memrefTy = cast<MemRefType>(type);
602 if (!memrefTy.hasStaticShape())
610 for (int64_t stride : strides)
611 if (ShapedType::isDynamic(stride))
614 return !ShapedType::isDynamic(offset);
625 if (failed(addressSpace))
637 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type)
const {
638 auto elementType =
convertType(type.getElementType());
641 if (type.getShape().empty())
644 type.getScalableDims().back());
646 "expected vector type compatible with the LLVM dialect");
650 if (llvm::is_contained(type.getScalableDims().drop_back(),
true))
652 auto shape = type.getShape();
653 for (
int i = shape.size() - 2; i >= 0; --i)
664 Type type,
bool useBarePtrCallConv)
const {
665 if (useBarePtrCallConv)
666 if (
auto memrefTy = dyn_cast<BaseMemRefType>(type))
667 return convertMemRefToBarePtr(memrefTy);
678 assert(stdTypes.size() == values.size() &&
679 "The number of types and values doesn't match");
680 for (
unsigned i = 0, end = values.size(); i < end; ++i)
681 if (
auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
683 memrefTy, values[i]);
691 assert(!types.empty() &&
"expected non-empty list of type");
692 if (types.size() == 1)
696 resultTypes.reserve(types.size());
697 for (
Type type : types) {
701 resultTypes.push_back(converted);
704 return LLVM::LLVMStructType::getLiteral(&
getContext(), resultTypes);
712 bool useBarePtrCallConv)
const {
713 assert(!types.empty() &&
"expected non-empty list of type");
716 if (types.size() == 1)
720 resultTypes.reserve(types.size());
721 for (
auto t : types) {
725 resultTypes.push_back(converted);
728 return LLVM::LLVMStructType::getLiteral(&
getContext(), resultTypes);
739 builder.
create<LLVM::AllocaOp>(loc, ptrType, operand.
getType(), one);
741 builder.
create<LLVM::StoreOp>(loc, operand, allocated);
748 bool useBarePtrCallConv)
const {
750 promotedOperands.reserve(operands.size());
752 for (
auto it : llvm::zip(opOperands, operands)) {
753 auto operand = std::get<0>(it);
754 auto llvmOperand = std::get<1>(it);
756 if (useBarePtrCallConv) {
759 if (dyn_cast<MemRefType>(operand.getType())) {
762 }
else if (isa<UnrankedMemRefType>(operand.getType())) {
763 llvm_unreachable(
"Unranked memrefs are not supported");
766 if (isa<UnrankedMemRefType>(operand.getType())) {
771 if (
auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
778 promotedOperands.push_back(llvmOperand);
780 return promotedOperands;
790 if (
auto memref = dyn_cast<MemRefType>(type)) {
795 if (converted.empty())
797 result.append(converted.begin(), converted.end());
800 if (isa<UnrankedMemRefType>(type)) {
802 if (converted.empty())
804 result.append(converted.begin(), converted.end());
810 result.push_back(converted);
824 result.push_back(llvmTy);
static llvm::ManagedStatic< PassManagerOptions > options
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into a ranked memref descriptor struct.
static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> UnrankedMemRefType.
static Value packUnrankedMemRefDesc(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into an unranked memref descriptor struct.
static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> MemRefType.
static bool isBarePointer(ValueRange values)
Helper function that checks if the given value range is a bare pointer.
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
SmallVector< Type, 5 > getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const
Convert a memref type into a list of LLVM IR types that will form the memref descriptor.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
SmallVector< Type, 2 > getUnrankedMemRefDescriptorFields() const
Convert an unranked memref type into a list of non-aggregate LLVM IR types that will form the unranke...
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.