14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
31 return *recursiveStack->second;
39 return *recursiveStackInserted.first->second;
51 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
options(
options),
52 dataLayoutAnalysis(analysis) {
53 assert(
llvmDialect &&
"LLVM IR dialect is not registered");
56 addConversion([&](ComplexType type) {
return convertComplexType(type); });
58 addConversion([&](FunctionType type) {
return convertFunctionType(type); });
59 addConversion([&](IndexType type) {
return convertIndexType(type); });
60 addConversion([&](IntegerType type) {
return convertIntegerType(type); });
61 addConversion([&](MemRefType type) {
return convertMemRefType(type); });
65 FailureOr<Type> llvmType = convertVectorType(type);
80 -> std::optional<LogicalResult> {
83 results.push_back(type);
87 if (type.isIdentified()) {
89 type.getContext(), (
"_Converted." + type.getName()).str());
92 if (llvm::count(recursiveStack, type)) {
93 results.push_back(convertedType);
96 recursiveStack.push_back(type);
97 auto popConversionCallStack = llvm::make_scope_exit(
98 [&recursiveStack]() { recursiveStack.pop_back(); });
101 convertedElemTypes.reserve(type.getBody().size());
102 if (failed(
convertTypes(type.getBody(), convertedElemTypes)))
107 if (!convertedType.isInitialized()) {
109 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
112 results.push_back(convertedType);
120 convertedType.isPacked() == type.isPacked()) {
121 results.push_back(convertedType);
129 convertedSubtypes.reserve(type.getBody().size());
130 if (failed(
convertTypes(type.getBody(), convertedSubtypes)))
134 type.getContext(), convertedSubtypes, type.isPacked()));
137 addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
138 if (
auto element =
convertType(type.getElementType()))
142 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
144 if (!convertedResType)
148 convertedArgTypes.reserve(type.getNumParams());
149 if (failed(
convertTypes(type.getParams(), convertedArgTypes)))
164 if (inputs.size() == 1) {
174 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, desc)
180 if (inputs.size() == 1) {
183 BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
186 Block *block = barePtr.getOwner();
187 if (!block->isEntryBlock() ||
188 !isa<FunctionOpInterface>(block->getParentOp()))
190 desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
193 desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
198 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, desc)
205 if (inputs.size() != 1)
208 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, inputs)
213 if (inputs.size() != 1)
216 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, inputs)
222 [](
BaseMemRefType memref, IntegerAttr addrspace) {
return addrspace; });
235 return options.
dataLayout.getPointerSizeInBits(addressSpace);
238 Type LLVMTypeConverter::convertIndexType(IndexType type)
const {
242 Type LLVMTypeConverter::convertIntegerType(IntegerType type)
const {
246 Type LLVMTypeConverter::convertFloatType(
FloatType type)
const {
260 Type LLVMTypeConverter::convertComplexType(ComplexType type)
const {
261 auto elementType =
convertType(type.getElementType());
263 {elementType, elementType});
268 Type LLVMTypeConverter::convertFunctionType(FunctionType type)
const {
278 assert(result.empty() &&
"Unexpected non-empty output");
279 result.resize(funcOp.getNumArguments(), std::nullopt);
280 bool foundByValByRefAttrs =
false;
281 for (
int argIdx : llvm::seq(funcOp.getNumArguments())) {
283 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
284 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
285 foundByValByRefAttrs =
true;
286 result[argIdx] = namedAttr;
292 if (!foundByValByRefAttrs)
304 Type LLVMTypeConverter::convertFunctionSignatureImpl(
305 FunctionType funcTy,
bool isVariadic,
bool useBarePtrCallConv,
306 LLVMTypeConverter::SignatureConversion &result,
307 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs)
const {
315 if (failed(funcArgConverter(*
this, type, converted)))
320 if (byValRefNonPtrAttrs !=
nullptr && !byValRefNonPtrAttrs->empty() &&
321 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
324 if (isa<LLVM::LLVMPointerType>(converted[0]))
325 (*byValRefNonPtrAttrs)[idx] = std::nullopt;
330 result.addInputs(idx, converted);
337 funcTy.getNumResults() == 0
347 FunctionType funcTy,
bool isVariadic,
bool useBarePtrCallConv,
349 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
355 FunctionOpInterface funcOp,
bool isVariadic,
bool useBarePtrCallConv,
357 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs)
const {
362 auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
363 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
364 result, &byValRefNonPtrAttrs);
369 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
373 Type resultType = type.getNumResults() == 0
380 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
384 inputs.push_back(ptrType);
388 for (
Type t : type.getInputs()) {
392 if (isa<MemRefType, UnrankedMemRefType>(t))
394 inputs.push_back(converted);
429 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
430 bool unpackAggregates)
const {
434 "conversion to strided form failed either due to non-strided layout "
435 "maps (which should have been normalized away) or other reasons");
444 if (failed(addressSpace)) {
446 "conversion of memref memory space ")
447 << type.getMemorySpace()
448 <<
" to integer address space "
449 "failed. Consider adding memory space conversions.";
457 auto rank = type.getRank();
461 if (unpackAggregates)
462 results.insert(results.end(), 2 * rank, indexTy);
479 Type LLVMTypeConverter::convertMemRefType(MemRefType type)
const {
483 getMemRefDescriptorFields(type,
false);
497 LLVMTypeConverter::getUnrankedMemRefDescriptorFields()
const {
509 Type LLVMTypeConverter::convertUnrankedMemRefType(
514 getUnrankedMemRefDescriptorFields());
521 std::optional<Attribute> converted =
527 if (
auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
528 if (explicitSpace.getType().isIndex() ||
529 explicitSpace.getType().isSignlessInteger())
530 return explicitSpace.getInt();
537 if (isa<UnrankedMemRefType>(type))
543 auto memrefTy = cast<MemRefType>(type);
544 if (!memrefTy.hasStaticShape())
552 for (int64_t stride : strides)
553 if (ShapedType::isDynamic(stride))
556 return !ShapedType::isDynamic(offset);
567 if (failed(addressSpace))
579 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type)
const {
580 auto elementType =
convertType(type.getElementType());
583 if (type.getShape().empty())
586 type.getScalableDims().back());
588 "expected vector type compatible with the LLVM dialect");
592 if (llvm::is_contained(type.getScalableDims().drop_back(),
true))
594 auto shape = type.getShape();
595 for (
int i = shape.size() - 2; i >= 0; --i)
606 Type type,
bool useBarePtrCallConv)
const {
607 if (useBarePtrCallConv)
608 if (
auto memrefTy = dyn_cast<BaseMemRefType>(type))
609 return convertMemRefToBarePtr(memrefTy);
620 assert(stdTypes.size() == values.size() &&
621 "The number of types and values doesn't match");
622 for (
unsigned i = 0, end = values.size(); i < end; ++i)
623 if (
auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
625 memrefTy, values[i]);
633 assert(!types.empty() &&
"expected non-empty list of type");
634 if (types.size() == 1)
638 resultTypes.reserve(types.size());
639 for (
Type type : types) {
643 resultTypes.push_back(converted);
654 bool useBarePtrCallConv)
const {
655 assert(!types.empty() &&
"expected non-empty list of type");
658 if (types.size() == 1)
662 resultTypes.reserve(types.size());
663 for (
auto t : types) {
667 resultTypes.push_back(converted);
681 builder.
create<LLVM::AllocaOp>(loc, ptrType, operand.
getType(), one);
683 builder.
create<LLVM::StoreOp>(loc, operand, allocated);
690 bool useBarePtrCallConv)
const {
692 promotedOperands.reserve(operands.size());
694 for (
auto it : llvm::zip(opOperands, operands)) {
695 auto operand = std::get<0>(it);
696 auto llvmOperand = std::get<1>(it);
698 if (useBarePtrCallConv) {
701 if (dyn_cast<MemRefType>(operand.getType())) {
704 }
else if (isa<UnrankedMemRefType>(operand.getType())) {
705 llvm_unreachable(
"Unranked memrefs are not supported");
708 if (isa<UnrankedMemRefType>(operand.getType())) {
713 if (
auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
720 promotedOperands.push_back(llvmOperand);
722 return promotedOperands;
732 if (
auto memref = dyn_cast<MemRefType>(type)) {
736 converter.getMemRefDescriptorFields(memref,
true);
737 if (converted.empty())
739 result.append(converted.begin(), converted.end());
742 if (isa<UnrankedMemRefType>(type)) {
743 auto converted = converter.getUnrankedMemRefDescriptorFields();
744 if (converted.empty())
746 result.append(converted.begin(), converted.end());
752 result.push_back(converted);
766 result.push_back(llvmTy);
static llvm::ManagedStatic< PassManagerOptions > options
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
unsigned getWidth()
Return the bitwidth of this float type.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
static LLVMStructType getIdentified(MLIRContext *context, StringRef name)
Gets or creates an identified struct with the given name in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isFloat8E4M3FN() const
bool isFloat8E3M4() const
bool isFloat8E4M3FNUZ() const
bool isFloat8E4M3B11FNUZ() const
bool isFloat6E3M2FN() const
bool isFloat8E5M2() const
bool isFloat8E8M0FNU() const
bool isFloat4E2M1FN() const
bool isFloat6E2M3FN() const
bool isFloat8E4M3() const
bool isFloat8E5M2FNUZ() const
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.