14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
31 return *recursiveStack->second;
39 return *recursiveStackInserted.first->second;
51 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
options(
options),
52 dataLayoutAnalysis(analysis) {
53 assert(
llvmDialect &&
"LLVM IR dialect is not registered");
56 addConversion([&](ComplexType type) {
return convertComplexType(type); });
58 addConversion([&](FunctionType type) {
return convertFunctionType(type); });
59 addConversion([&](IndexType type) {
return convertIndexType(type); });
60 addConversion([&](IntegerType type) {
return convertIntegerType(type); });
61 addConversion([&](MemRefType type) {
return convertMemRefType(type); });
81 addConversion([&](LLVM::LLVMPointerType type) -> std::optional<Type> {
84 if (
auto pointee =
convertType(type.getElementType()))
90 -> std::optional<LogicalResult> {
93 results.push_back(type);
97 if (type.isIdentified()) {
99 type.getContext(), (
"_Converted_" + type.getName()).str());
100 unsigned counter = 1;
101 while (convertedType.isInitialized()) {
102 assert(counter != UINT_MAX &&
103 "about to overflow struct renaming counter in conversion");
106 (
"_Converted_" + std::to_string(counter) + type.getName()).str());
110 if (llvm::count(recursiveStack, type)) {
111 results.push_back(convertedType);
114 recursiveStack.push_back(type);
115 auto popConversionCallStack = llvm::make_scope_exit(
116 [&recursiveStack]() { recursiveStack.pop_back(); });
119 convertedElemTypes.reserve(type.getBody().size());
123 if (
failed(convertedType.setBody(convertedElemTypes, type.isPacked())))
125 results.push_back(convertedType);
130 convertedSubtypes.reserve(type.getBody().size());
135 type.getContext(), convertedSubtypes, type.isPacked()));
138 addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
139 if (
auto element =
convertType(type.getElementType()))
143 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
145 if (!convertedResType)
149 convertedArgTypes.reserve(type.getNumParams());
162 Location loc) -> std::optional<Value> {
163 if (inputs.size() == 1)
170 Location loc) -> std::optional<Value> {
173 if (inputs.size() == 1)
181 Location loc) -> std::optional<Value> {
182 if (inputs.size() != 1)
185 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, inputs)
190 Location loc) -> std::optional<Value> {
191 if (inputs.size() != 1)
194 return builder.
create<UnrealizedConversionCastOp>(loc, resultType, inputs)
200 [](
BaseMemRefType memref, IntegerAttr addrspace) {
return addrspace; });
212 LLVM::LLVMPointerType
214 unsigned int addressSpace)
const {
221 return options.
dataLayout.getPointerSizeInBits(addressSpace);
224 Type LLVMTypeConverter::convertIndexType(IndexType type)
const {
228 Type LLVMTypeConverter::convertIntegerType(IntegerType type)
const {
232 Type LLVMTypeConverter::convertFloatType(
FloatType type)
const {
243 Type LLVMTypeConverter::convertComplexType(ComplexType type)
const {
244 auto elementType =
convertType(type.getElementType());
246 {elementType, elementType});
251 Type LLVMTypeConverter::convertFunctionType(FunctionType type)
const {
252 SignatureConversion conversion(type.getNumInputs());
265 FunctionType funcTy,
bool isVariadic,
bool useBarePtrCallConv,
274 if (
failed(funcArgConverter(*
this, type, converted)))
283 funcTy.getNumResults() == 0
294 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
298 Type resultType = type.getNumResults() == 0
304 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
312 for (
Type t : type.getInputs()) {
316 if (isa<MemRefType, UnrankedMemRefType>(t))
318 inputs.push_back(converted);
353 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
354 bool unpackAggregates)
const {
358 "conversion to strided form failed either due to non-strided layout "
359 "maps (which should have been normalized away) or other reasons");
368 if (
failed(addressSpace)) {
370 "conversion of memref memory space ")
371 << type.getMemorySpace()
372 <<
" to integer address space "
373 "failed. Consider adding memory space conversions.";
381 auto rank = type.getRank();
385 if (unpackAggregates)
386 results.insert(results.end(), 2 * rank, indexTy);
403 Type LLVMTypeConverter::convertMemRefType(MemRefType type)
const {
407 getMemRefDescriptorFields(type,
false);
421 LLVMTypeConverter::getUnrankedMemRefDescriptorFields()
const {
433 Type LLVMTypeConverter::convertUnrankedMemRefType(
438 getUnrankedMemRefDescriptorFields());
445 std::optional<Attribute> converted =
451 if (
auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
452 return explicitSpace.getInt();
458 if (isa<UnrankedMemRefType>(type))
464 auto memrefTy = cast<MemRefType>(type);
465 if (!memrefTy.hasStaticShape())
473 for (int64_t stride : strides)
474 if (ShapedType::isDynamic(stride))
477 return !ShapedType::isDynamic(offset);
500 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type)
const {
501 auto elementType =
convertType(type.getElementType());
504 if (type.getShape().empty())
507 type.getScalableDims().back());
509 "expected vector type compatible with the LLVM dialect");
511 if (llvm::is_contained(type.getScalableDims().drop_back(),
true))
513 auto shape = type.getShape();
514 for (
int i = shape.size() - 2; i >= 0; --i)
525 Type type,
bool useBarePtrCallConv)
const {
526 if (useBarePtrCallConv)
527 if (
auto memrefTy = dyn_cast<BaseMemRefType>(type))
528 return convertMemRefToBarePtr(memrefTy);
539 assert(stdTypes.size() == values.size() &&
540 "The number of types and values doesn't match");
541 for (
unsigned i = 0, end = values.size(); i < end; ++i)
542 if (
auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
544 memrefTy, values[i]);
552 assert(!types.empty() &&
"expected non-empty list of type");
553 if (types.size() == 1)
557 resultTypes.reserve(types.size());
558 for (
Type type : types) {
562 resultTypes.push_back(converted);
573 bool useBarePtrCallConv)
const {
574 assert(!types.empty() &&
"expected non-empty list of type");
577 if (types.size() == 1)
581 resultTypes.reserve(types.size());
582 for (
auto t : types) {
586 resultTypes.push_back(converted);
600 builder.
create<LLVM::AllocaOp>(loc, ptrType, operand.
getType(), one);
602 builder.
create<LLVM::StoreOp>(loc, operand, allocated);
609 bool useBarePtrCallConv)
const {
611 promotedOperands.reserve(operands.size());
613 for (
auto it : llvm::zip(opOperands, operands)) {
614 auto operand = std::get<0>(it);
615 auto llvmOperand = std::get<1>(it);
617 if (useBarePtrCallConv) {
620 if (dyn_cast<MemRefType>(operand.getType())) {
623 }
else if (isa<UnrankedMemRefType>(operand.getType())) {
624 llvm_unreachable(
"Unranked memrefs are not supported");
627 if (isa<UnrankedMemRefType>(operand.getType())) {
632 if (
auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
639 promotedOperands.push_back(llvmOperand);
641 return promotedOperands;
651 if (
auto memref = dyn_cast<MemRefType>(type)) {
655 converter.getMemRefDescriptorFields(memref,
true);
656 if (converted.empty())
658 result.append(converted.begin(), converted.end());
661 if (isa<UnrankedMemRefType>(type)) {
662 auto converted = converter.getUnrankedMemRefDescriptorFields();
663 if (converted.empty())
665 result.append(converted.begin(), converted.end());
671 result.push_back(converted);
685 result.push_back(llvmTy);
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
unsigned getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
unsigned getWidth()
Return the bitwidth of this float type.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0) const
Creates an LLVM pointer type with the given element type and address space.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
bool useOpaquePointers() const
Returns true if using opaque pointers was enabled in the lowering options.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
static LLVMStructType getIdentified(MLIRContext *context, StringRef name)
Gets or creates an identified struct with the given name in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat8E4M3FN() const
bool isFloat8E4M3FNUZ() const
bool isFloat8E4M3B11FNUZ() const
bool isFloat8E5M2() const
bool isFloat8E5M2FNUZ() const
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.