MLIR 22.0.0git
TypeConverter.cpp
Go to the documentation of this file.
1//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "llvm/ADT/ScopeExit.h"
14#include "llvm/Support/Threading.h"
15#include <memory>
16#include <mutex>
17#include <optional>
18
19using namespace mlir;
20
22 {
23 // Most of the time, the entry already exists in the map.
24 std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
25 std::defer_lock);
26 if (getContext().isMultithreadingEnabled())
27 lock.lock();
28 auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
29 if (recursiveStack != conversionCallStack.end())
30 return *recursiveStack->second;
31 }
32
33 // First time this thread gets here, we have to get an exclusive access to
34 // insert in the map
35 std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
36 auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
37 llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
38 return *recursiveStackInserted.first->second;
39}
40
41/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
45
46/// Helper function that checks if the given value range is a bare pointer.
47static bool isBarePointer(ValueRange values) {
48 return values.size() == 1 &&
49 isa<LLVM::LLVMPointerType>(values.front().getType());
50}
51
52/// Pack SSA values into an unranked memref descriptor struct.
54 UnrankedMemRefType resultType,
55 ValueRange inputs, Location loc,
56 const LLVMTypeConverter &converter) {
57 // Note: Bare pointers are not supported for unranked memrefs because a
58 // memref descriptor cannot be built just from a bare pointer.
59 if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
60 return Value();
61 return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
62 inputs);
63}
64
65/// Pack SSA values into a ranked memref descriptor struct.
66static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
67 ValueRange inputs, Location loc,
68 const LLVMTypeConverter &converter) {
69 assert(resultType && "expected non-null result type");
70 if (isBarePointer(inputs))
71 return MemRefDescriptor::fromStaticShape(builder, loc, converter,
72 resultType, inputs[0]);
73 if (TypeRange(inputs) ==
74 converter.getMemRefDescriptorFields(resultType,
75 /*unpackAggregates=*/true))
76 return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
77 // The inputs are neither a bare pointer nor an unpacked memref descriptor.
78 // This materialization function cannot be used.
79 return Value();
80}
81
82/// MemRef descriptor elements -> UnrankedMemRefType
84 UnrankedMemRefType resultType,
85 ValueRange inputs, Location loc,
86 const LLVMTypeConverter &converter) {
87 // A source materialization must return a value of type
88 // `resultType`, so insert a cast from the memref descriptor type
89 // (!llvm.struct) to the original memref type.
90 Value packed =
91 packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
92 if (!packed)
93 return Value();
94 return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
95 .getResult(0);
96}
97
98/// MemRef descriptor elements -> MemRefType
100 MemRefType resultType,
101 ValueRange inputs, Location loc,
102 const LLVMTypeConverter &converter) {
103 // A source materialization must return a value of type `resultType`,
104 // so insert a cast from the memref descriptor type (!llvm.struct) to the
105 // original memref type.
106 Value packed =
107 packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
108 if (!packed)
109 return Value();
110 return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
111 .getResult(0);
112}
113
114/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
116 const LowerToLLVMOptions &options,
117 const DataLayoutAnalysis *analysis)
118 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
119 dataLayoutAnalysis(analysis) {
120 assert(llvmDialect && "LLVM IR dialect is not registered");
121
122 // Register conversions for the builtin types.
123 addConversion([&](ComplexType type) { return convertComplexType(type); });
124 addConversion([&](FloatType type) { return convertFloatType(type); });
125 addConversion([&](FunctionType type) { return convertFunctionType(type); });
126 addConversion([&](IndexType type) { return convertIndexType(type); });
127 addConversion([&](IntegerType type) { return convertIntegerType(type); });
128 addConversion([&](MemRefType type) { return convertMemRefType(type); });
129 addConversion(
130 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
131 addConversion([&](VectorType type) -> std::optional<Type> {
132 FailureOr<Type> llvmType = convertVectorType(type);
133 if (failed(llvmType))
134 return std::nullopt;
135 return llvmType;
136 });
137
138 // LLVM-compatible types are legal, so add a pass-through conversion. Do this
139 // before the conversions below since conversions are attempted in reverse
140 // order and those should take priority.
141 addConversion([](Type type) {
142 return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
143 : std::nullopt;
144 });
145
146 addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
147 -> std::optional<LogicalResult> {
148 // Fastpath for types that won't be converted by this callback anyway.
149 if (LLVM::isCompatibleType(type)) {
150 results.push_back(type);
151 return success();
152 }
153
154 if (type.isIdentified()) {
155 auto convertedType = LLVM::LLVMStructType::getIdentified(
156 type.getContext(), ("_Converted." + type.getName()).str());
157
159 if (llvm::count(recursiveStack, type)) {
160 results.push_back(convertedType);
161 return success();
162 }
163 recursiveStack.push_back(type);
164 auto popConversionCallStack = llvm::make_scope_exit(
165 [&recursiveStack]() { recursiveStack.pop_back(); });
166
167 SmallVector<Type> convertedElemTypes;
168 convertedElemTypes.reserve(type.getBody().size());
169 if (failed(convertTypes(type.getBody(), convertedElemTypes)))
170 return std::nullopt;
171
172 // If the converted type has not been initialized yet, just set its body
173 // to be the converted arguments and return.
174 if (!convertedType.isInitialized()) {
175 if (failed(
176 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
177 return failure();
178 }
179 results.push_back(convertedType);
180 return success();
181 }
182
183 // If it has been initialized, has the same body and packed bit, just use
184 // it. This ensures that recursive structs keep being recursive rather
185 // than including a non-updated name.
186 if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
187 convertedType.isPacked() == type.isPacked()) {
188 results.push_back(convertedType);
189 return success();
190 }
191
192 return failure();
193 }
194
195 SmallVector<Type> convertedSubtypes;
196 convertedSubtypes.reserve(type.getBody().size());
197 if (failed(convertTypes(type.getBody(), convertedSubtypes)))
198 return std::nullopt;
199
200 results.push_back(LLVM::LLVMStructType::getLiteral(
201 type.getContext(), convertedSubtypes, type.isPacked()));
202 return success();
203 });
204 addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
205 if (auto element = convertType(type.getElementType()))
206 return LLVM::LLVMArrayType::get(element, type.getNumElements());
207 return std::nullopt;
208 });
209 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
210 Type convertedResType = convertType(type.getReturnType());
211 if (!convertedResType)
212 return std::nullopt;
213
214 SmallVector<Type> convertedArgTypes;
215 convertedArgTypes.reserve(type.getNumParams());
216 if (failed(convertTypes(type.getParams(), convertedArgTypes)))
217 return std::nullopt;
218
219 return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
220 type.isVarArg());
221 });
222
223 // Add generic source and target materializations to handle cases where
224 // non-LLVM types persist after an LLVM conversion.
225 addSourceMaterialization([&](OpBuilder &builder, Type resultType,
226 ValueRange inputs, Location loc) {
227 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
228 .getResult(0);
229 });
230 addTargetMaterialization([&](OpBuilder &builder, Type resultType,
231 ValueRange inputs, Location loc) {
232 return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
233 .getResult(0);
234 });
235
236 // Source materializations convert from the new block argument types
237 // (multiple SSA values that make up a memref descriptor) back to the
238 // original block argument type.
239 addSourceMaterialization([&](OpBuilder &builder,
240 UnrankedMemRefType resultType, ValueRange inputs,
241 Location loc) {
242 return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
243 *this);
244 });
245 addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
246 ValueRange inputs, Location loc) {
247 return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
248 });
249
250 // Bare pointer -> Packed MemRef descriptor
251 addTargetMaterialization([&](OpBuilder &builder, Type resultType,
252 ValueRange inputs, Location loc,
253 Type originalType) -> Value {
254 // The original MemRef type is required to build a MemRef descriptor
255 // because the sizes/strides of the MemRef cannot be inferred from just the
256 // bare pointer.
257 if (!originalType)
258 return Value();
259 if (resultType != convertType(originalType))
260 return Value();
261 if (auto memrefType = dyn_cast<MemRefType>(originalType))
262 return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
263 if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
264 return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
265 *this);
266 return Value();
267 });
268
269 // Integer memory spaces map to themselves.
270 addTypeAttributeConversion(
271 [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
272}
273
274/// Returns the MLIR context.
276 return *getDialect()->getContext();
277}
278
280 return IntegerType::get(&getContext(), getIndexTypeBitwidth());
281}
282
283unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
284 return options.dataLayout.getPointerSizeInBits(addressSpace);
285}
286
287Type LLVMTypeConverter::convertIndexType(IndexType type) const {
288 return getIndexType();
289}
290
291Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
292 return IntegerType::get(&getContext(), type.getWidth());
293}
294
295Type LLVMTypeConverter::convertFloatType(FloatType type) const {
296 // Valid LLVM float types are used directly.
297 if (LLVM::isCompatibleType(type))
298 return type;
299
300 // F4, F6, F8 types are converted to integer types with the same bit width.
301 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
302 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
303 Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
304 Float8E8M0FNUType>(type))
305 return IntegerType::get(&getContext(), type.getWidth());
306
307 // Other floating-point types: A custom type conversion rule must be
308 // specified by the user.
309 return Type();
310}
311
312// Convert a `ComplexType` to an LLVM type. The result is a complex number
313// struct with entries for the
314// 1. real part and for the
315// 2. imaginary part.
316Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
317 auto elementType = convertType(type.getElementType());
318 return LLVM::LLVMStructType::getLiteral(&getContext(),
319 {elementType, elementType});
320}
321
322// Except for signatures, MLIR function types are converted into LLVM
323// pointer-to-function types.
324Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
325 return LLVM::LLVMPointerType::get(type.getContext());
326}
327
328/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
329/// function arguments. Returns an empty container if none of these attributes
330/// are found in any of the arguments.
331static void
332filterByValRefArgAttrs(FunctionOpInterface funcOp,
333 SmallVectorImpl<std::optional<NamedAttribute>> &result) {
334 assert(result.empty() && "Unexpected non-empty output");
335 result.resize(funcOp.getNumArguments(), std::nullopt);
336 bool foundByValByRefAttrs = false;
337 for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
338 for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
339 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
340 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
341 foundByValByRefAttrs = true;
342 result[argIdx] = namedAttr;
343 break;
344 }
345 }
346 }
347
348 if (!foundByValByRefAttrs)
349 result.clear();
350}
351
352// Function types are converted to LLVM Function types by recursively converting
353// argument and result types. If MLIR Function has zero results, the LLVM
354// Function has one VoidType result. If MLIR Function has more than one result,
355// they are packed into an LLVM StructType in their order of appearance.
356// If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
357// `llvm.byref` function arguments which are not LLVM pointers are overridden
358// with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
359// converted to LLVM pointer types are removed from `byValRefNonPtrAttrs`.
360Type LLVMTypeConverter::convertFunctionSignatureImpl(
361 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
362 LLVMTypeConverter::SignatureConversion &result,
363 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
364 // Select the argument converter depending on the calling convention.
365 useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
366 auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
368
369 // Convert argument types one by one and check for errors.
370 for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
371 SmallVector<Type, 8> converted;
372 if (failed(funcArgConverter(*this, type, converted)))
373 return {};
374
375 // Rewrite converted type of `llvm.byval` or `llvm.byref` function
376 // argument that was not converted to an LLVM pointer type.
377 if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
378 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
379 // If the argument was already converted to an LLVM pointer type, we stop
380 // tracking it as it doesn't need more processing.
381 if (isa<LLVM::LLVMPointerType>(converted[0]))
382 (*byValRefNonPtrAttrs)[idx] = std::nullopt;
383 else
384 converted[0] = LLVM::LLVMPointerType::get(&getContext());
385 }
386
387 result.addInputs(idx, converted);
388 }
389
390 // If function does not return anything, create the void result type,
391 // if it returns one element, convert it, otherwise pack the result types into
392 // a struct.
393 Type resultType =
394 funcTy.getNumResults() == 0
395 ? LLVM::LLVMVoidType::get(&getContext())
396 : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
397 if (!resultType)
398 return {};
399 return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
400 isVariadic);
401}
402
404 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
405 LLVMTypeConverter::SignatureConversion &result) const {
406 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
407 result,
408 /*byValRefNonPtrAttrs=*/nullptr);
409}
410
412 FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
413 LLVMTypeConverter::SignatureConversion &result,
414 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
415 // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
416 // that were not converted to LLVM pointer types will be returned for further
417 // processing.
418 filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
419 auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
420 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
421 result, &byValRefNonPtrAttrs);
422}
423
424/// Converts the function type to a C-compatible format, in particular using
425/// pointers to memref descriptors for arguments.
426std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
429
430 Type resultType = type.getNumResults() == 0
431 ? LLVM::LLVMVoidType::get(&getContext())
432 : packFunctionResults(type.getResults());
433 if (!resultType)
434 return {};
435
436 auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
437 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
438 if (structType) {
439 // Struct types cannot be safely returned via C interface. Make this a
440 // pointer argument, instead.
441 inputs.push_back(ptrType);
442 resultType = LLVM::LLVMVoidType::get(&getContext());
443 }
444
445 for (Type t : type.getInputs()) {
446 auto converted = convertType(t);
447 if (!converted || !LLVM::isCompatibleType(converted))
448 return {};
449 if (isa<MemRefType, UnrankedMemRefType>(t))
450 converted = ptrType;
451 inputs.push_back(converted);
452 }
453
454 return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
455}
456
457/// Convert a memref type into a list of LLVM IR types that will form the
458/// memref descriptor. The result contains the following types:
459/// 1. The pointer to the allocated data buffer, followed by
460/// 2. The pointer to the aligned data buffer, followed by
461/// 3. A lowered `index`-type integer containing the distance between the
462/// beginning of the buffer and the first element to be accessed through the
463/// view, followed by
464/// 4. An array containing as many `index`-type integers as the rank of the
465/// MemRef: the array represents the size, in number of elements, of the memref
466/// along the given dimension. For constant MemRef dimensions, the
467/// corresponding size entry is a constant whose runtime value must match the
468/// static value, followed by
469/// 5. A second array containing as many `index`-type integers as the rank of
470/// the MemRef: the second array represents the "stride" (in tensor abstraction
471/// sense), i.e. the number of consecutive elements of the underlying buffer.
472/// TODO: add assertions for the static cases.
473///
474/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
475/// are expanded into individual index-type elements.
476///
477/// template <typename Elem, typename Index, size_t Rank>
478/// struct {
479/// Elem *allocatedPtr;
480/// Elem *alignedPtr;
481/// Index offset;
482/// Index sizes[Rank]; // omitted when rank == 0
483/// Index strides[Rank]; // omitted when rank == 0
484/// };
487 bool unpackAggregates) const {
488 if (!type.isStrided()) {
489 emitError(
490 UnknownLoc::get(type.getContext()),
491 "conversion to strided form failed either due to non-strided layout "
492 "maps (which should have been normalized away) or other reasons");
493 return {};
494 }
495
496 Type elementType = convertType(type.getElementType());
497 if (!elementType)
498 return {};
499
500 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
501 if (failed(addressSpace)) {
502 emitError(UnknownLoc::get(type.getContext()),
503 "conversion of memref memory space ")
504 << type.getMemorySpace()
505 << " to integer address space "
506 "failed. Consider adding memory space conversions.";
507 return {};
508 }
509 auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
510
511 auto indexTy = getIndexType();
512
513 SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
514 auto rank = type.getRank();
515 if (rank == 0)
516 return results;
517
518 if (unpackAggregates)
519 results.insert(results.end(), 2 * rank, indexTy);
520 else
521 results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
522 return results;
523}
524
525unsigned
527 const DataLayout &layout) const {
528 // Compute the descriptor size given that of its components indicated above.
529 unsigned space = *getMemRefAddressSpace(type);
530 return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
531 (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
532}
533
534/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
535/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
536Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
537 // When converting a MemRefType to a struct with descriptor fields, do not
538 // unpack the `sizes` and `strides` arrays.
540 getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
541 if (types.empty())
542 return {};
543 return LLVM::LLVMStructType::getLiteral(&getContext(), types);
544}
545
546/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
547/// that will form the unranked memref descriptor. In particular, the fields
548/// for an unranked memref descriptor are:
549/// 1. index-typed rank, the dynamic rank of this MemRef
550/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
551/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
552/// be unranked.
555 return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
556}
557
559 UnrankedMemRefType type, const DataLayout &layout) const {
560 // Compute the descriptor size given that of its components indicated above.
561 unsigned space = *getMemRefAddressSpace(type);
562 return layout.getTypeSize(getIndexType()) +
563 llvm::divideCeil(getPointerBitwidth(space), 8);
564}
565
566Type LLVMTypeConverter::convertUnrankedMemRefType(
567 UnrankedMemRefType type) const {
568 if (!convertType(type.getElementType()))
569 return {};
570 return LLVM::LLVMStructType::getLiteral(&getContext(),
572}
573
574FailureOr<unsigned>
576 if (!type.getMemorySpace()) // Default memory space -> 0.
577 return 0;
578 std::optional<Attribute> converted =
579 convertTypeAttribute(type, type.getMemorySpace());
580 if (!converted)
581 return failure();
582 if (!(*converted)) // Conversion to default is 0.
583 return 0;
584 if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
585 if (explicitSpace.getType().isIndex() ||
586 explicitSpace.getType().isSignlessInteger())
587 return explicitSpace.getInt();
588 }
589 return failure();
590}
591
592// Check if a memref type can be converted to a bare pointer.
594 if (isa<UnrankedMemRefType>(type))
595 // Unranked memref is not supported in the bare pointer calling convention.
596 return false;
597
598 // Check that the memref has static shape, strides and offset. Otherwise, it
599 // cannot be lowered to a bare pointer.
600 auto memrefTy = cast<MemRefType>(type);
601 if (!memrefTy.hasStaticShape())
602 return false;
603
604 int64_t offset = 0;
606 if (failed(memrefTy.getStridesAndOffset(strides, offset)))
607 return false;
608
609 for (int64_t stride : strides)
610 if (ShapedType::isDynamic(stride))
611 return false;
612
613 return ShapedType::isStatic(offset);
614}
615
616/// Convert a memref type to a bare pointer to the memref element type.
617Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
618 if (!canConvertToBarePtr(type))
619 return {};
620 Type elementType = convertType(type.getElementType());
621 if (!elementType)
622 return {};
623 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
624 if (failed(addressSpace))
625 return {};
626 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
627}
628
629/// Convert an n-D vector type to an LLVM vector type:
630/// * 0-D `vector<T>` are converted to vector<1xT>
631/// * 1-D `vector<axT>` remains as is while,
632/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
633/// `!llvm.array<ax...array<jxvector<kxT>>>`.
634/// As LLVM supports arrays of scalable vectors, this method will also convert
635/// n-D scalable vectors provided that only the trailing dim is scalable.
636FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
637 auto elementType = convertType(type.getElementType());
638 if (!elementType)
639 return {};
640 if (type.getShape().empty())
641 return VectorType::get({1}, elementType);
642 Type vectorType = VectorType::get(type.getShape().back(), elementType,
643 type.getScalableDims().back());
644 assert(LLVM::isCompatibleVectorType(vectorType) &&
645 "expected vector type compatible with the LLVM dialect");
646 // For n-D vector types for which a _non-trailing_ dim is scalable,
647 // return a failure. Supporting such cases would require LLVM
648 // to support something akin "scalable arrays" of vectors.
649 if (llvm::is_contained(type.getScalableDims().drop_back(), true))
650 return failure();
651 auto shape = type.getShape();
652 for (int i = shape.size() - 2; i >= 0; --i)
653 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
654 return vectorType;
655}
656
657/// Convert a type in the context of the default or bare pointer calling
658/// convention. Calling convention sensitive types, such as MemRefType and
659/// UnrankedMemRefType, are converted following the specific rules for the
660/// calling convention. Calling convention independent types are converted
661/// following the default LLVM type conversions.
663 Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
664 if (useBarePtrCallConv) {
665 if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
666 Type converted = convertMemRefToBarePtr(memrefTy);
667 if (!converted)
668 return failure();
669 result.push_back(converted);
670 return success();
671 }
672 }
673
674 return convertType(type, result);
675}
676
677/// Convert a non-empty list of types of values produced by an operation into an
678/// LLVM-compatible type. In particular, if more than one value is
679/// produced, create a literal structure with elements that correspond to each
680/// of the types converted with `convertType`.
682 assert(!types.empty() && "expected non-empty list of type");
683 if (types.size() == 1)
684 return convertType(types[0]);
685
686 SmallVector<Type> resultTypes;
687 resultTypes.reserve(types.size());
688 for (Type type : types) {
689 Type converted = convertType(type);
690 if (!converted || !LLVM::isCompatibleType(converted))
691 return {};
692 resultTypes.push_back(converted);
693 }
694
695 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
696}
697
698/// Convert a non-empty list of types to be returned from a function into an
699/// LLVM-compatible type. In particular, if more than one value is returned,
700/// create an LLVM dialect structure type with elements that correspond to each
701/// of the types converted with `convertCallingConventionType`.
703 TypeRange types, bool useBarePtrCallConv,
704 SmallVector<SmallVector<Type>> *groupedTypes,
705 int64_t *numConvertedTypes) const {
706 assert(!types.empty() && "expected non-empty list of type");
707 assert((!groupedTypes || groupedTypes->empty()) &&
708 "expected groupedTypes to be empty");
709
710 useBarePtrCallConv |= options.useBarePtrCallConv;
711 SmallVector<Type> resultTypes;
712 resultTypes.reserve(types.size());
713 size_t sizeBefore = 0;
714 for (auto t : types) {
715 if (failed(
716 convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
717 return {};
718 if (groupedTypes) {
719 SmallVector<Type> &group = groupedTypes->emplace_back();
720 llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
721 }
722 sizeBefore = resultTypes.size();
723 }
724
725 if (numConvertedTypes)
726 *numConvertedTypes = resultTypes.size();
727 if (resultTypes.size() == 1)
728 return resultTypes.front();
729 if (resultTypes.empty())
730 return {};
731 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
732}
733
735 OpBuilder &builder) const {
736 // Alloca with proper alignment. We do not expect optimizations of this
737 // alloca op and so we omit allocating at the entry block.
738 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
739 Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(),
740 builder.getIndexAttr(1));
741 Value allocated =
742 LLVM::AllocaOp::create(builder, loc, ptrType, operand.getType(), one);
743 // Store into the alloca'ed descriptor.
744 LLVM::StoreOp::create(builder, loc, operand, allocated);
745 return allocated;
746}
747
749 Location loc, ValueRange opOperands, ValueRange adaptorOperands,
750 OpBuilder &builder, bool useBarePtrCallConv) const {
752 for (size_t i = 0, e = adaptorOperands.size(); i < e; i++)
753 ranges.push_back(adaptorOperands.slice(i, 1));
754 return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv);
755}
756
758 Location loc, ValueRange opOperands, ArrayRef<ValueRange> adaptorOperands,
759 OpBuilder &builder, bool useBarePtrCallConv) const {
760 SmallVector<Value, 4> promotedOperands;
761 promotedOperands.reserve(adaptorOperands.size());
762 useBarePtrCallConv |= options.useBarePtrCallConv;
763 for (auto [operand, llvmOperand] :
764 llvm::zip_equal(opOperands, adaptorOperands)) {
765 if (useBarePtrCallConv) {
766 // For the bare-ptr calling convention, we only have to extract the
767 // aligned pointer of a memref.
768 if (isa<MemRefType>(operand.getType())) {
769 assert(llvmOperand.size() == 1 && "Expected a single operand");
770 MemRefDescriptor desc(llvmOperand.front());
771 promotedOperands.push_back(desc.alignedPtr(builder, loc));
772 continue;
773 } else if (isa<UnrankedMemRefType>(operand.getType())) {
774 llvm_unreachable("Unranked memrefs are not supported");
775 }
776 } else {
777 if (isa<UnrankedMemRefType>(operand.getType())) {
778 assert(llvmOperand.size() == 1 && "Expected a single operand");
779 UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(),
780 promotedOperands);
781 continue;
782 }
783 if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
784 assert(llvmOperand.size() == 1 && "Expected a single operand");
785 MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType,
786 promotedOperands);
787 continue;
788 }
789 }
790
791 llvm::append_range(promotedOperands, llvmOperand);
792 }
793 return promotedOperands;
794}
795
796/// Callback to convert function argument types. It converts a MemRef function
797/// argument to a list of non-aggregate types containing descriptor
798/// information, and an UnrankedMemRef function argument to a list containing
799/// the rank and a pointer to a descriptor struct.
800LogicalResult
803 if (auto memref = dyn_cast<MemRefType>(type)) {
804 // In signatures, Memref descriptors are expanded into lists of
805 // non-aggregate values.
806 auto converted =
807 converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
808 if (converted.empty())
809 return failure();
810 result.append(converted.begin(), converted.end());
811 return success();
812 }
813 if (isa<UnrankedMemRefType>(type)) {
814 auto converted = converter.getUnrankedMemRefDescriptorFields();
815 if (converted.empty())
816 return failure();
817 result.append(converted.begin(), converted.end());
818 return success();
819 }
820 return converter.convertType(type, result);
821}
822
823/// Callback to convert function argument types. It converts MemRef function
824/// arguments to bare pointers to the MemRef element type.
825LogicalResult
828 return converter.convertCallingConventionType(
829 type, result,
830 /*useBarePointerCallConv=*/true);
831}
return success()
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute > > &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into a ranked memref descriptor struct.
static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> UnrankedMemRefType.
static Value packUnrankedMemRefDesc(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into an unranked memref descriptor struct.
static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> MemRefType.
static bool isBarePointer(ValueRange values)
Helper function that checks if the given value range is a bare pointer.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerType getI64Type()
Definition Builders.cpp:65
MLIRContext * getContext() const
Definition Builders.h:56
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ArrayRef< ValueRange > adaptorOperands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false, SmallVector< SmallVector< Type > > *groupedTypes=nullptr, int64_t *numConvertedTypes=nullptr) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
LogicalResult convertCallingConventionType(Type type, SmallVectorImpl< Type > &result, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
SmallVector< Type, 5 > getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const
Convert a memref type into a list of LLVM IR types that will form the memref descriptor.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
SmallVector< Type, 2 > getUnrankedMemRefDescriptorFields() const
Convert an unranked memref type into a list of non-aggregate LLVM IR types that will form the unranke...
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type` or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:207
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.