13#include "llvm/ADT/ScopeExit.h"
14#include "llvm/Support/Threading.h"
30 return *recursiveStack->second;
38 return *recursiveStackInserted.first->second;
48 return values.size() == 1 &&
49 isa<LLVM::LLVMPointerType>(values.front().
getType());
69 assert(resultType &&
"expected non-null result type");
72 resultType, inputs[0]);
94 return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
100 MemRefType resultType,
110 return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
118 :
llvmDialect(ctx->getOrLoadDialect<
LLVM::LLVMDialect>()), options(options),
119 dataLayoutAnalysis(analysis) {
120 assert(
llvmDialect &&
"LLVM IR dialect is not registered");
123 addConversion([&](ComplexType type) {
return convertComplexType(type); });
124 addConversion([&](FloatType type) {
return convertFloatType(type); });
125 addConversion([&](FunctionType type) {
return convertFunctionType(type); });
126 addConversion([&](IndexType type) {
return convertIndexType(type); });
127 addConversion([&](IntegerType type) {
return convertIntegerType(type); });
128 addConversion([&](MemRefType type) {
return convertMemRefType(type); });
131 addConversion([&](VectorType type) -> std::optional<Type> {
132 FailureOr<Type> llvmType = convertVectorType(type);
133 if (failed(llvmType))
141 addConversion([](
Type type) {
147 -> std::optional<LogicalResult> {
150 results.push_back(type);
154 if (type.isIdentified()) {
155 auto convertedType = LLVM::LLVMStructType::getIdentified(
156 type.getContext(), (
"_Converted." + type.getName()).str());
159 if (llvm::count(recursiveStack, type)) {
160 results.push_back(convertedType);
163 recursiveStack.push_back(type);
164 auto popConversionCallStack = llvm::make_scope_exit(
165 [&recursiveStack]() { recursiveStack.pop_back(); });
168 convertedElemTypes.reserve(type.getBody().size());
169 if (failed(convertTypes(type.getBody(), convertedElemTypes)))
174 if (!convertedType.isInitialized()) {
176 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
179 results.push_back(convertedType);
187 convertedType.isPacked() == type.isPacked()) {
188 results.push_back(convertedType);
196 convertedSubtypes.reserve(type.getBody().size());
197 if (failed(convertTypes(type.getBody(), convertedSubtypes)))
200 results.push_back(LLVM::LLVMStructType::getLiteral(
201 type.getContext(), convertedSubtypes, type.isPacked()));
204 addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
205 if (
auto element = convertType(type.getElementType()))
206 return LLVM::LLVMArrayType::get(element, type.getNumElements());
209 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
210 Type convertedResType = convertType(type.getReturnType());
211 if (!convertedResType)
215 convertedArgTypes.reserve(type.getNumParams());
216 if (failed(convertTypes(type.getParams(), convertedArgTypes)))
219 return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
225 addSourceMaterialization([&](
OpBuilder &builder,
Type resultType,
227 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
230 addTargetMaterialization([&](
OpBuilder &builder,
Type resultType,
232 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
239 addSourceMaterialization([&](
OpBuilder &builder,
245 addSourceMaterialization([&](
OpBuilder &builder, MemRefType resultType,
251 addTargetMaterialization([&](
OpBuilder &builder,
Type resultType,
259 if (resultType != convertType(originalType))
261 if (
auto memrefType = dyn_cast<MemRefType>(originalType))
263 if (
auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
270 addTypeAttributeConversion(
284 return options.dataLayout.getPointerSizeInBits(addressSpace);
287Type LLVMTypeConverter::convertIndexType(IndexType type)
const {
291Type LLVMTypeConverter::convertIntegerType(IntegerType type)
const {
292 return IntegerType::get(&
getContext(), type.getWidth());
295Type LLVMTypeConverter::convertFloatType(FloatType type)
const {
301 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
302 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
303 Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
304 Float8E8M0FNUType>(type))
305 return IntegerType::get(&
getContext(), type.getWidth());
316Type LLVMTypeConverter::convertComplexType(ComplexType type)
const {
317 auto elementType = convertType(type.getElementType());
318 return LLVM::LLVMStructType::getLiteral(&
getContext(),
319 {elementType, elementType});
324Type LLVMTypeConverter::convertFunctionType(FunctionType type)
const {
325 return LLVM::LLVMPointerType::get(type.getContext());
334 assert(
result.empty() &&
"Unexpected non-empty output");
335 result.resize(funcOp.getNumArguments(), std::nullopt);
336 bool foundByValByRefAttrs =
false;
337 for (
int argIdx : llvm::seq(funcOp.getNumArguments())) {
339 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
340 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
341 foundByValByRefAttrs =
true;
342 result[argIdx] = namedAttr;
348 if (!foundByValByRefAttrs)
360Type LLVMTypeConverter::convertFunctionSignatureImpl(
361 FunctionType funcTy,
bool isVariadic,
bool useBarePtrCallConv,
362 LLVMTypeConverter::SignatureConversion &
result,
363 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs)
const {
365 useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
370 for (
auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
371 SmallVector<Type, 8> converted;
372 if (
failed(funcArgConverter(*
this, type, converted)))
377 if (byValRefNonPtrAttrs !=
nullptr && !byValRefNonPtrAttrs->empty() &&
378 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
381 if (isa<LLVM::LLVMPointerType>(converted[0]))
382 (*byValRefNonPtrAttrs)[idx] = std::nullopt;
384 converted[0] = LLVM::LLVMPointerType::get(&
getContext());
387 result.addInputs(idx, converted);
394 funcTy.getNumResults() == 0
399 return LLVM::LLVMFunctionType::get(resultType,
result.getConvertedTypes(),
404 FunctionType funcTy,
bool isVariadic,
bool useBarePtrCallConv,
405 LLVMTypeConverter::SignatureConversion &
result)
const {
406 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
412 FunctionOpInterface funcOp,
bool isVariadic,
bool useBarePtrCallConv,
413 LLVMTypeConverter::SignatureConversion &
result,
414 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs)
const {
419 auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
420 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
421 result, &byValRefNonPtrAttrs);
426std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
430 Type resultType = type.getNumResults() == 0
436 auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
437 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
441 inputs.push_back(ptrType);
442 resultType = LLVM::LLVMVoidType::get(&
getContext());
445 for (
Type t : type.getInputs()) {
446 auto converted = convertType(t);
449 if (isa<MemRefType, UnrankedMemRefType>(t))
451 inputs.push_back(converted);
454 return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
487 bool unpackAggregates)
const {
488 if (!type.isStrided()) {
490 UnknownLoc::get(type.getContext()),
491 "conversion to strided form failed either due to non-strided layout "
492 "maps (which should have been normalized away) or other reasons");
496 Type elementType = convertType(type.getElementType());
501 if (failed(addressSpace)) {
502 emitError(UnknownLoc::get(type.getContext()),
503 "conversion of memref memory space ")
504 << type.getMemorySpace()
505 <<
" to integer address space "
506 "failed. Consider adding memory space conversions.";
509 auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
514 auto rank = type.getRank();
518 if (unpackAggregates)
519 results.insert(results.end(), 2 * rank, indexTy);
521 results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
536Type LLVMTypeConverter::convertMemRefType(MemRefType type)
const {
543 return LLVM::LLVMStructType::getLiteral(&
getContext(), types);
566Type LLVMTypeConverter::convertUnrankedMemRefType(
568 if (!convertType(type.getElementType()))
570 return LLVM::LLVMStructType::getLiteral(&
getContext(),
578 std::optional<Attribute> converted =
584 if (
auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
585 if (explicitSpace.getType().isIndex() ||
586 explicitSpace.getType().isSignlessInteger())
587 return explicitSpace.getInt();
594 if (isa<UnrankedMemRefType>(type))
600 auto memrefTy = cast<MemRefType>(type);
601 if (!memrefTy.hasStaticShape())
606 if (failed(memrefTy.getStridesAndOffset(strides, offset)))
610 if (ShapedType::isDynamic(stride))
613 return ShapedType::isStatic(offset);
626 return LLVM::LLVMPointerType::get(type.
getContext(), *addressSpace);
636FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type)
const {
637 auto elementType = convertType(type.getElementType());
640 if (type.getShape().empty())
641 return VectorType::get({1}, elementType);
642 Type vectorType = VectorType::get(type.getShape().back(), elementType,
643 type.getScalableDims().back());
645 "expected vector type compatible with the LLVM dialect");
649 if (llvm::is_contained(type.getScalableDims().drop_back(),
true))
651 auto shape = type.getShape();
652 for (
int i = shape.size() - 2; i >= 0; --i)
653 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
664 if (useBarePtrCallConv) {
665 if (
auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
666 Type converted = convertMemRefToBarePtr(memrefTy);
669 result.push_back(converted);
674 return convertType(type,
result);
682 assert(!types.empty() &&
"expected non-empty list of type");
683 if (types.size() == 1)
684 return convertType(types[0]);
687 resultTypes.reserve(types.size());
688 for (
Type type : types) {
689 Type converted = convertType(type);
692 resultTypes.push_back(converted);
695 return LLVM::LLVMStructType::getLiteral(&
getContext(), resultTypes);
703 TypeRange types,
bool useBarePtrCallConv,
705 int64_t *numConvertedTypes)
const {
706 assert(!types.empty() &&
"expected non-empty list of type");
707 assert((!groupedTypes || groupedTypes->empty()) &&
708 "expected groupedTypes to be empty");
710 useBarePtrCallConv |= options.useBarePtrCallConv;
712 resultTypes.reserve(types.size());
713 size_t sizeBefore = 0;
714 for (
auto t : types) {
720 llvm::append_range(group,
ArrayRef(resultTypes).drop_front(sizeBefore));
722 sizeBefore = resultTypes.size();
725 if (numConvertedTypes)
726 *numConvertedTypes = resultTypes.size();
727 if (resultTypes.size() == 1)
728 return resultTypes.front();
729 if (resultTypes.empty())
731 return LLVM::LLVMStructType::getLiteral(&
getContext(), resultTypes);
738 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
739 Value one = LLVM::ConstantOp::create(builder, loc, builder.
getI64Type(),
742 LLVM::AllocaOp::create(builder, loc, ptrType, operand.
getType(), one);
744 LLVM::StoreOp::create(builder, loc, operand, allocated);
750 OpBuilder &builder,
bool useBarePtrCallConv)
const {
752 for (
size_t i = 0, e = adaptorOperands.size(); i < e; i++)
753 ranges.push_back(adaptorOperands.slice(i, 1));
754 return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv);
759 OpBuilder &builder,
bool useBarePtrCallConv)
const {
761 promotedOperands.reserve(adaptorOperands.size());
762 useBarePtrCallConv |= options.useBarePtrCallConv;
763 for (
auto [operand, llvmOperand] :
764 llvm::zip_equal(opOperands, adaptorOperands)) {
765 if (useBarePtrCallConv) {
768 if (isa<MemRefType>(operand.getType())) {
769 assert(llvmOperand.size() == 1 &&
"Expected a single operand");
771 promotedOperands.push_back(desc.
alignedPtr(builder, loc));
773 }
else if (isa<UnrankedMemRefType>(operand.getType())) {
774 llvm_unreachable(
"Unranked memrefs are not supported");
777 if (isa<UnrankedMemRefType>(operand.getType())) {
778 assert(llvmOperand.size() == 1 &&
"Expected a single operand");
783 if (
auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
784 assert(llvmOperand.size() == 1 &&
"Expected a single operand");
791 llvm::append_range(promotedOperands, llvmOperand);
793 return promotedOperands;
803 if (
auto memref = dyn_cast<MemRefType>(type)) {
808 if (converted.empty())
810 result.append(converted.begin(), converted.end());
813 if (isa<UnrankedMemRefType>(type)) {
815 if (converted.empty())
817 result.append(converted.begin(), converted.end());
820 return converter.convertType(type,
result);
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute > > &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into a ranked memref descriptor struct.
static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> UnrankedMemRefType.
static Value packUnrankedMemRefDesc(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into an unranked memref descriptor struct.
static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> MemRefType.
static bool isBarePointer(ValueRange values)
Helper function that checks if the given value range is a bare pointer.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ArrayRef< ValueRange > adaptorOperands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false, SmallVector< SmallVector< Type > > *groupedTypes=nullptr, int64_t *numConvertedTypes=nullptr) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
LogicalResult convertCallingConventionType(Type type, SmallVectorImpl< Type > &result, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
SmallVector< Type, 5 > getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const
Convert a memref type into a list of LLVM IR types that will form the memref descriptor.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
SmallVector< Type, 2 > getUnrankedMemRefDescriptorFields() const
Convert an unranked memref type into a list of non-aggregate LLVM IR types that will form the unranke...
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.