13 #include "llvm/ADT/ScopeExit.h" 
   14 #include "llvm/Support/Threading.h" 
   30       return *recursiveStack->second;
 
   38   return *recursiveStackInserted.first->second;
 
   48   return values.size() == 1 &&
 
   49          isa<LLVM::LLVMPointerType>(values.front().
getType());
 
   69   assert(resultType && 
"expected non-null result type");
 
   72                                              resultType, inputs[0]);
 
   94   return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
 
  100                                          MemRefType resultType,
 
  110   return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
 
  118     : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), 
options(
options),
 
  120   assert(
llvmDialect && 
"LLVM IR dialect is not registered");
 
  123   addConversion([&](ComplexType type) { 
return convertComplexType(type); });
 
  124   addConversion([&](FloatType type) { 
return convertFloatType(type); });
 
  125   addConversion([&](FunctionType type) { 
return convertFunctionType(type); });
 
  126   addConversion([&](IndexType type) { 
return convertIndexType(type); });
 
  127   addConversion([&](IntegerType type) { 
return convertIntegerType(type); });
 
  128   addConversion([&](MemRefType type) { 
return convertMemRefType(type); });
 
  132     FailureOr<Type> llvmType = convertVectorType(type);
 
  147                     -> std::optional<LogicalResult> {
 
  150       results.push_back(type);
 
  154     if (type.isIdentified()) {
 
  155       auto convertedType = LLVM::LLVMStructType::getIdentified(
 
  156           type.getContext(), (
"_Converted." + type.getName()).str());
 
  159       if (llvm::count(recursiveStack, type)) {
 
  160         results.push_back(convertedType);
 
  163       recursiveStack.push_back(type);
 
  164       auto popConversionCallStack = llvm::make_scope_exit(
 
  165           [&recursiveStack]() { recursiveStack.pop_back(); });
 
  168       convertedElemTypes.reserve(type.getBody().size());
 
  174       if (!convertedType.isInitialized()) {
 
  176                 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
 
  179         results.push_back(convertedType);
 
  187           convertedType.isPacked() == type.isPacked()) {
 
  188         results.push_back(convertedType);
 
  196     convertedSubtypes.reserve(type.getBody().size());
 
  200     results.push_back(LLVM::LLVMStructType::getLiteral(
 
  201         type.getContext(), convertedSubtypes, type.isPacked()));
 
  204   addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
 
  205     if (
auto element = 
convertType(type.getElementType()))
 
  209   addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
 
  211     if (!convertedResType)
 
  215     convertedArgTypes.reserve(type.getNumParams());
 
  227     return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
 
  232     return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
 
  261     if (
auto memrefType = dyn_cast<MemRefType>(originalType))
 
  263     if (
auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
 
  271       [](
BaseMemRefType memref, IntegerAttr addrspace) { 
return addrspace; });
 
  284   return options.
dataLayout.getPointerSizeInBits(addressSpace);
 
  287 Type LLVMTypeConverter::convertIndexType(IndexType type)
 const {
 
  291 Type LLVMTypeConverter::convertIntegerType(IntegerType type)
 const {
 
  295 Type LLVMTypeConverter::convertFloatType(FloatType type)
 const {
 
  301   if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
 
  302           Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
 
  303           Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
 
  304           Float8E8M0FNUType>(type))
 
  316 Type LLVMTypeConverter::convertComplexType(ComplexType type)
 const {
 
  317   auto elementType = 
convertType(type.getElementType());
 
  318   return LLVM::LLVMStructType::getLiteral(&
getContext(),
 
  319                                           {elementType, elementType});
 
  324 Type LLVMTypeConverter::convertFunctionType(FunctionType type)
 const {
 
  334   assert(result.empty() && 
"Unexpected non-empty output");
 
  335   result.resize(funcOp.getNumArguments(), std::nullopt);
 
  336   bool foundByValByRefAttrs = 
false;
 
  337   for (
int argIdx : llvm::seq(funcOp.getNumArguments())) {
 
  339       if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
 
  340            namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
 
  341         foundByValByRefAttrs = 
true;
 
  342         result[argIdx] = namedAttr;
 
  348   if (!foundByValByRefAttrs)
 
  360 Type LLVMTypeConverter::convertFunctionSignatureImpl(
 
  361     FunctionType funcTy, 
bool isVariadic, 
bool useBarePtrCallConv,
 
  362     LLVMTypeConverter::SignatureConversion &result,
 
  363     SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs)
 const {
 
  372     if (
failed(funcArgConverter(*
this, type, converted)))
 
  377     if (byValRefNonPtrAttrs != 
nullptr && !byValRefNonPtrAttrs->empty() &&
 
  378         converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
 
  381       if (isa<LLVM::LLVMPointerType>(converted[0]))
 
  382         (*byValRefNonPtrAttrs)[idx] = std::nullopt;
 
  387     result.addInputs(idx, converted);
 
  394       funcTy.getNumResults() == 0
 
  404     FunctionType funcTy, 
bool isVariadic, 
bool useBarePtrCallConv,
 
  406   return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
 
  412     FunctionOpInterface funcOp, 
bool isVariadic, 
bool useBarePtrCallConv,
 
  414     SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs)
 const {
 
  419   auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
 
  420   return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
 
  421                                       result, &byValRefNonPtrAttrs);
 
  426 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
 
  430   Type resultType = type.getNumResults() == 0
 
  437   auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
 
  441     inputs.push_back(ptrType);
 
  445   for (
Type t : type.getInputs()) {
 
  449     if (isa<MemRefType, UnrankedMemRefType>(t))
 
  451     inputs.push_back(converted);
 
  487                                              bool unpackAggregates)
 const {
 
  488   if (!type.isStrided()) {
 
  491         "conversion to strided form failed either due to non-strided layout " 
  492         "maps (which should have been normalized away) or other reasons");
 
  501   if (
failed(addressSpace)) {
 
  503               "conversion of memref memory space ")
 
  504         << type.getMemorySpace()
 
  505         << 
" to integer address space " 
  506            "failed. Consider adding memory space conversions.";
 
  514   auto rank = type.getRank();
 
  518   if (unpackAggregates)
 
  519     results.insert(results.end(), 2 * rank, indexTy);
 
  536 Type LLVMTypeConverter::convertMemRefType(MemRefType type)
 const {
 
  543   return LLVM::LLVMStructType::getLiteral(&
getContext(), types);
 
  566 Type LLVMTypeConverter::convertUnrankedMemRefType(
 
  570   return LLVM::LLVMStructType::getLiteral(&
getContext(),
 
  578   std::optional<Attribute> converted =
 
  584   if (
auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
 
  585     if (explicitSpace.getType().isIndex() ||
 
  586         explicitSpace.getType().isSignlessInteger())
 
  587       return explicitSpace.getInt();
 
  594   if (isa<UnrankedMemRefType>(type))
 
  600   auto memrefTy = cast<MemRefType>(type);
 
  601   if (!memrefTy.hasStaticShape())
 
  606   if (
failed(memrefTy.getStridesAndOffset(strides, offset)))
 
  609   for (int64_t stride : strides)
 
  610     if (ShapedType::isDynamic(stride))
 
  613   return ShapedType::isStatic(offset);
 
  636 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type)
 const {
 
  637   auto elementType = 
convertType(type.getElementType());
 
  640   if (type.getShape().empty())
 
  643                                     type.getScalableDims().back());
 
  645          "expected vector type compatible with the LLVM dialect");
 
  649   if (llvm::is_contained(type.getScalableDims().drop_back(), 
true))
 
  651   auto shape = type.getShape();
 
  652   for (
int i = shape.size() - 2; i >= 0; --i)
 
  664   if (useBarePtrCallConv) {
 
  665     if (
auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
 
  666       Type converted = convertMemRefToBarePtr(memrefTy);
 
  669       result.push_back(converted);
 
  682   assert(!types.empty() && 
"expected non-empty list of type");
 
  683   if (types.size() == 1)
 
  687   resultTypes.reserve(types.size());
 
  688   for (
Type type : types) {
 
  692     resultTypes.push_back(converted);
 
  695   return LLVM::LLVMStructType::getLiteral(&
getContext(), resultTypes);
 
  703     TypeRange types, 
bool useBarePtrCallConv,
 
  705     int64_t *numConvertedTypes)
 const {
 
  706   assert(!types.empty() && 
"expected non-empty list of type");
 
  707   assert((!groupedTypes || groupedTypes->empty()) &&
 
  708          "expected groupedTypes to be empty");
 
  712   resultTypes.reserve(types.size());
 
  713   size_t sizeBefore = 0;
 
  714   for (
auto t : types) {
 
  720       llvm::append_range(group, 
ArrayRef(resultTypes).drop_front(sizeBefore));
 
  722     sizeBefore = resultTypes.size();
 
  725   if (numConvertedTypes)
 
  726     *numConvertedTypes = resultTypes.size();
 
  727   if (resultTypes.size() == 1)
 
  728     return resultTypes.front();
 
  729   if (resultTypes.empty())
 
  731   return LLVM::LLVMStructType::getLiteral(&
getContext(), resultTypes);
 
  739   Value one = LLVM::ConstantOp::create(builder, loc, builder.
getI64Type(),
 
  742       LLVM::AllocaOp::create(builder, loc, ptrType, operand.
getType(), one);
 
  744   LLVM::StoreOp::create(builder, loc, operand, allocated);
 
  750     OpBuilder &builder, 
bool useBarePtrCallConv)
 const {
 
  752   for (
size_t i = 0, e = adaptorOperands.size(); i < e; i++)
 
  753     ranges.push_back(adaptorOperands.slice(i, 1));
 
  754   return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv);
 
  759     OpBuilder &builder, 
bool useBarePtrCallConv)
 const {
 
  761   promotedOperands.reserve(adaptorOperands.size());
 
  763   for (
auto [operand, llvmOperand] :
 
  764        llvm::zip_equal(opOperands, adaptorOperands)) {
 
  765     if (useBarePtrCallConv) {
 
  768       if (isa<MemRefType>(operand.getType())) {
 
  769         assert(llvmOperand.size() == 1 && 
"Expected a single operand");
 
  771         promotedOperands.push_back(desc.
alignedPtr(builder, loc));
 
  773       } 
else if (isa<UnrankedMemRefType>(operand.getType())) {
 
  774         llvm_unreachable(
"Unranked memrefs are not supported");
 
  777       if (isa<UnrankedMemRefType>(operand.getType())) {
 
  778         assert(llvmOperand.size() == 1 && 
"Expected a single operand");
 
  783       if (
auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
 
  784         assert(llvmOperand.size() == 1 && 
"Expected a single operand");
 
  791     llvm::append_range(promotedOperands, llvmOperand);
 
  793   return promotedOperands;
 
  803   if (
auto memref = dyn_cast<MemRefType>(type)) {
 
  808     if (converted.empty())
 
  810     result.append(converted.begin(), converted.end());
 
  813   if (isa<UnrankedMemRefType>(type)) {
 
  815     if (converted.empty())
 
  817     result.append(converted.begin(), converted.end());
 
static llvm::ManagedStatic< PassManagerOptions > options
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into a ranked memref descriptor struct.
static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> UnrankedMemRefType.
static Value packUnrankedMemRefDesc(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into an unranked memref descriptor struct.
static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> MemRefType.
static bool isBarePointer(ValueRange values)
Helper function that checks if the given value range is a bare pointer.
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ArrayRef< ValueRange > adaptorOperands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false, SmallVector< SmallVector< Type >> *groupedTypes=nullptr, int64_t *numConvertedTypes=nullptr) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
LogicalResult convertCallingConventionType(Type type, SmallVectorImpl< Type > &result, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
SmallVector< Type, 5 > getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const
Convert a memref type into a list of LLVM IR types that will form the memref descriptor.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
SmallVector< Type, 2 > getUnrankedMemRefDescriptorFields() const
Convert an unranked memref type into a list of non-aggregate LLVM IR types that will form the unranke...
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...