MLIR 23.0.0git
LLVMTypes.cpp
Go to the documentation of this file.
1//===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the types for the LLVM dialect in MLIR. These MLIR types
10// correspond to the LLVM IR type system.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TypeDetail.h"
15
21#include "mlir/IR/TypeSupport.h"
22
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/TypeSize.h"
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::LLVM;
29
/// Number of bits in one byte. Used below to convert between byte-based sizes
/// and alignments and bit-based sizes and data layout entry values.
static constexpr uint64_t kBitsInByte = 8;
31
32//===----------------------------------------------------------------------===//
33// custom<FunctionTypes>
34//===----------------------------------------------------------------------===//
35
36static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37 bool &isVarArg) {
38 isVarArg = false;
39 // `(` `)`
40 if (succeeded(p.parseOptionalRParen()))
41 return success();
42
43 // `(` `...` `)`
44 if (succeeded(p.parseOptionalEllipsis())) {
45 isVarArg = true;
46 return p.parseRParen();
47 }
48
49 // type (`,` type)* (`,` `...`)?
50 Type type;
51 if (parsePrettyLLVMType(p, type))
52 return failure();
53 params.push_back(type);
54 while (succeeded(p.parseOptionalComma())) {
55 if (succeeded(p.parseOptionalEllipsis())) {
56 isVarArg = true;
57 return p.parseRParen();
58 }
59 if (parsePrettyLLVMType(p, type))
60 return failure();
61 params.push_back(type);
62 }
63 return p.parseRParen();
64}
65
67 bool isVarArg) {
68 llvm::interleaveComma(params, p,
69 [&](Type type) { printPrettyLLVMType(p, type); });
70 if (isVarArg) {
71 if (!params.empty())
72 p << ", ";
73 p << "...";
74 }
75 p << ')';
76}
77
78//===----------------------------------------------------------------------===//
79// custom<ExtTypeParams>
80//===----------------------------------------------------------------------===//
81
82/// Parses the parameter list for a target extension type. The parameter list
83/// contains an optional list of type parameters, followed by an optional list
84/// of integer parameters. Type and integer parameters cannot be interleaved in
85/// the list.
86/// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87/// typeList ::= type ("," type)*
88/// intList ::= integer ("," integer)*
89static ParseResult
92 bool parseType = true;
93 auto typeOrIntParser = [&]() -> ParseResult {
94 unsigned int i;
95 auto intResult = p.parseOptionalInteger(i);
96 if (intResult.has_value() && !failed(*intResult)) {
97 // Successfully parsed an integer.
98 intParams.push_back(i);
99 // After the first integer was successfully parsed, no
100 // more types can be parsed.
101 parseType = false;
102 return success();
103 }
104 if (parseType) {
105 Type t;
106 if (!parsePrettyLLVMType(p, t)) {
107 // Successfully parsed a type.
108 typeParams.push_back(t);
109 return success();
110 }
111 }
112 return failure();
113 };
114 if (p.parseCommaSeparatedList(typeOrIntParser)) {
116 "failed to parse parameter list for target extension type");
117 return failure();
118 }
119 return success();
120}
121
123 ArrayRef<unsigned int> intParams) {
124 p << typeParams;
125 if (!typeParams.empty() && !intParams.empty())
126 p << ", ";
127
128 p << intParams;
129}
130
131//===----------------------------------------------------------------------===//
132// ODS-Generated Definitions
133//===----------------------------------------------------------------------===//
134
/// These are unused for now.
/// TODO: Move over to these once more types have been migrated to TypeDef.
///
/// Forward declarations of the ODS-style parser/printer entry points. They are
/// presumably defined by the generated LLVMTypes.cpp.inc included below
/// (verify against the generated file). Dialect-level parsing and printing
/// currently go through `detail::parseType` / `detail::printType` instead (see
/// LLVMDialect::parseType/printType). `[[maybe_unused]]` keeps the compiler
/// from warning about these otherwise-unreferenced static functions.
[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
                                                           AsmPrinter &printer);
141
142#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143
144#define GET_TYPEDEF_CLASSES
145#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146
147//===----------------------------------------------------------------------===//
148// LLVMArrayType
149//===----------------------------------------------------------------------===//
150
151bool LLVMArrayType::isValidElementType(Type type) {
152 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153 LLVMFunctionType, TokenType>(type);
154}
155
156LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
157 assert(elementType && "expected non-null subtype");
158 return Base::get(elementType.getContext(), elementType, numElements);
159}
160
161LLVMArrayType
162LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
163 Type elementType, uint64_t numElements) {
164 assert(elementType && "expected non-null subtype");
165 return Base::getChecked(emitError, elementType.getContext(), elementType,
166 numElements);
167}
168
169LogicalResult
170LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
171 Type elementType, uint64_t numElements) {
172 if (!isValidElementType(elementType))
173 return emitError() << "invalid array element type: " << elementType;
174 return success();
175}
176
177//===----------------------------------------------------------------------===//
178// DataLayoutTypeInterface
179//===----------------------------------------------------------------------===//
180
181llvm::TypeSize
182LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183 DataLayoutEntryListRef params) const {
184 return llvm::TypeSize::getFixed(kBitsInByte *
185 getTypeSize(dataLayout, params));
186}
187
188llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189 DataLayoutEntryListRef params) const {
190 return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191 dataLayout.getTypeABIAlignment(getElementType())) *
193}
194
195uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196 DataLayoutEntryListRef params) const {
197 return dataLayout.getTypeABIAlignment(getElementType());
198}
199
200uint64_t
201LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202 DataLayoutEntryListRef params) const {
203 return dataLayout.getTypePreferredAlignment(getElementType());
204}
205
206//===----------------------------------------------------------------------===//
207// LLVMByteType
208//===----------------------------------------------------------------------===//
209
210llvm::TypeSize
211LLVMByteType::getTypeSizeInBits(const DataLayout &dataLayout,
212 DataLayoutEntryListRef params) const {
213 return llvm::TypeSize::getFixed(getBitWidth());
214}
215
216uint64_t LLVMByteType::getABIAlignment(const DataLayout &dataLayout,
217 DataLayoutEntryListRef params) const {
218 return llvm::PowerOf2Ceil(llvm::divideCeil(getBitWidth(), kBitsInByte));
219}
220
221LogicalResult LLVMByteType::verify(function_ref<InFlightDiagnostic()> emitError,
222 unsigned bitWidth) {
223 if (bitWidth == 0)
224 return emitError() << "bitwidth must be greater than 0";
225
226 // Mirror LLVM IR, which limits the bit width to fit in 23 bits.
227 constexpr unsigned kMaxBitWidth = 1 << 23;
228 if (bitWidth >= kMaxBitWidth)
229 return emitError() << "bitwidth must be less than " << kMaxBitWidth
230 << ", but got " << bitWidth;
231 return success();
232}
233
234//===----------------------------------------------------------------------===//
235// Function type.
236//===----------------------------------------------------------------------===//
237
238bool LLVMFunctionType::isValidArgumentType(Type type) {
239 if (auto structType = dyn_cast<LLVMStructType>(type))
240 return !structType.isOpaque();
241
242 return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
243}
244
245bool LLVMFunctionType::isValidResultType(Type type) {
246 return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
247}
248
249LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
250 bool isVarArg) {
251 assert(result && "expected non-null result");
252 return Base::get(result.getContext(), result, arguments, isVarArg);
253}
254
255LLVMFunctionType
256LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
257 Type result, ArrayRef<Type> arguments,
258 bool isVarArg) {
259 assert(result && "expected non-null result");
260 return Base::getChecked(emitError, result.getContext(), result, arguments,
261 isVarArg);
262}
263
264LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
265 TypeRange results) const {
266 // LLVM functions have exactly one return type. An empty results range
267 // corresponds to a void return type (as FunctionOpInterface represents void
268 // functions with 0 results). More than one result is not valid.
269 if (results.size() > 1)
270 return {};
271 Type resultType =
272 results.empty() ? LLVMVoidType::get(getContext()) : results[0];
273 if (!isValidResultType(resultType))
274 return {};
275 if (!llvm::all_of(inputs, isValidArgumentType))
276 return {};
277 return get(resultType, llvm::to_vector(inputs), isVarArg());
278}
279
280ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
281 return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
282}
283
284LogicalResult
285LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
286 Type result, ArrayRef<Type> arguments, bool) {
287 if (!isValidResultType(result))
288 return emitError() << "invalid function result type: " << result;
289
290 for (Type arg : arguments)
291 if (!isValidArgumentType(arg))
292 return emitError() << "invalid function argument type: " << arg;
293
294 return success();
295}
296
297//===----------------------------------------------------------------------===//
298// DataLayoutTypeInterface
299//===----------------------------------------------------------------------===//
300
/// Fallbacks used for pointers when the data layout has no matching entry:
/// the pointer size in bits and the pointer alignment in bytes.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
static constexpr uint64_t kDefaultPointerAlignment = 8;
303
305 PtrDLEntryPos pos) {
306 auto spec = cast<DenseIntElementsAttr>(attr);
307 auto idx = static_cast<int64_t>(pos);
308 if (idx >= spec.size())
309 return std::nullopt;
310 return spec.getValues<uint64_t>()[idx];
311}
312
313/// Returns the part of the data layout entry that corresponds to `pos` for the
314/// given `type` by interpreting the list of entries `params`. For the pointer
315/// type in the default address space, returns the default value if the entries
316/// do not provide a custom one, for other address spaces returns std::nullopt.
317static std::optional<uint64_t>
319 PtrDLEntryPos pos) {
320 // First, look for the entry for the pointer in the current address space.
321 Attribute currentEntry;
322 for (DataLayoutEntryInterface entry : params) {
323 if (!entry.isTypeEntry())
324 continue;
325 if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() ==
326 type.getAddressSpace()) {
327 currentEntry = entry.getValue();
328 break;
329 }
330 }
331 if (currentEntry) {
332 std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
333 // If the optional `PtrDLEntryPos::Index` entry is not available, use the
334 // pointer size as the index bitwidth.
335 if (!value && pos == PtrDLEntryPos::Index)
336 value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
337 bool isSizeOrIndex =
339 return *value / (isSizeOrIndex ? 1 : kBitsInByte);
340 }
341
342 // If not found, and this is the pointer to the default memory space, assume
343 // 64-bit pointers.
344 if (type.getAddressSpace() == 0) {
345 bool isSizeOrIndex =
347 return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
348 }
349
350 return std::nullopt;
351}
352
353llvm::TypeSize
354LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
355 DataLayoutEntryListRef params) const {
356 if (std::optional<uint64_t> size =
358 return llvm::TypeSize::getFixed(*size);
359
360 // For other memory spaces, use the size of the pointer to the default memory
361 // space.
362 return dataLayout.getTypeSizeInBits(get(getContext()));
363}
364
365uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
366 DataLayoutEntryListRef params) const {
367 if (std::optional<uint64_t> alignment =
369 return *alignment;
370
371 return dataLayout.getTypeABIAlignment(get(getContext()));
372}
373
374uint64_t
375LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
376 DataLayoutEntryListRef params) const {
377 if (std::optional<uint64_t> alignment =
379 return *alignment;
380
381 return dataLayout.getTypePreferredAlignment(get(getContext()));
382}
383
384std::optional<uint64_t>
385LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
386 DataLayoutEntryListRef params) const {
387 if (std::optional<uint64_t> indexBitwidth =
389 return *indexBitwidth;
390
391 return dataLayout.getTypeIndexBitwidth(get(getContext()));
392}
393
394bool LLVMPointerType::areCompatible(
396 DataLayoutSpecInterface newSpec,
397 const DataLayoutIdentifiedEntryMap &map) const {
398 for (DataLayoutEntryInterface newEntry : newLayout) {
399 if (!newEntry.isTypeEntry())
400 continue;
401 uint64_t size = kDefaultPointerSizeBits;
402 uint64_t abi = kDefaultPointerAlignment;
403 auto newType =
404 llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey()));
405 const auto *it =
406 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
407 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
408 return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
409 newType.getAddressSpace();
410 }
411 return false;
412 });
413 if (it == oldLayout.end()) {
414 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
415 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
416 return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
417 }
418 return false;
419 });
420 }
421 if (it != oldLayout.end()) {
424 }
425
426 Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
427 uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
428 uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
429 if (size != newSize || abi < newAbi || abi % newAbi != 0)
430 return false;
431 }
432 return true;
433}
434
435LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
436 Location loc) const {
437 for (DataLayoutEntryInterface entry : entries) {
438 if (!entry.isTypeEntry())
439 continue;
440 auto key = llvm::cast<Type>(entry.getKey());
441 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
442 if (!values || (values.size() != 3 && values.size() != 4)) {
443 return emitError(loc)
444 << "expected layout attribute for " << key
445 << " to be a dense integer elements attribute with 3 or 4 "
446 "elements";
447 }
448 if (!values.getElementType().isInteger(64))
449 return emitError(loc) << "expected i64 parameters for " << key;
450
453 return emitError(loc) << "preferred alignment is expected to be at least "
454 "as large as ABI alignment";
455 }
456 }
457 return success();
458}
459
460//===----------------------------------------------------------------------===//
461// Struct type.
462//===----------------------------------------------------------------------===//
463
464bool LLVMStructType::isValidElementType(Type type) {
465 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
466 LLVMFunctionType, TokenType>(type);
467}
468
469LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
470 StringRef name) {
471 return Base::get(context, name, /*opaque=*/false);
472}
473
474LLVMStructType LLVMStructType::getIdentifiedChecked(
475 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
476 StringRef name) {
477 return Base::getChecked(emitError, context, name, /*opaque=*/false);
478}
479
480LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
481 StringRef name,
482 ArrayRef<Type> elements,
483 bool isPacked) {
484 std::string stringName = name.str();
485 unsigned counter = 0;
486 do {
487 auto type = LLVMStructType::getIdentified(context, stringName);
488 if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
489 counter += 1;
490 stringName = (Twine(name) + "." + std::to_string(counter)).str();
491 continue;
492 }
493 return type;
494 } while (true);
495}
496
497LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
498 ArrayRef<Type> types, bool isPacked) {
499 return Base::get(context, types, isPacked);
500}
501
502LLVMStructType
503LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
504 MLIRContext *context, ArrayRef<Type> types,
505 bool isPacked) {
506 return Base::getChecked(emitError, context, types, isPacked);
507}
508
509LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
510 return Base::get(context, name, /*opaque=*/true);
511}
512
513LLVMStructType
514LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
515 MLIRContext *context, StringRef name) {
516 return Base::getChecked(emitError, context, name, /*opaque=*/true);
517}
518
519LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
520 assert(isIdentified() && "can only set bodies of identified structs");
521 assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
522 "expected valid body types");
523 return Base::mutate(types, isPacked);
524}
525
526bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
527bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
528bool LLVMStructType::isOpaque() const {
529 return getImpl()->isIdentified() &&
530 (getImpl()->isOpaque() || !getImpl()->isInitialized());
531}
532bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
533StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
534ArrayRef<Type> LLVMStructType::getBody() const {
535 return isIdentified() ? getImpl()->getIdentifiedStructBody()
536 : getImpl()->getTypeList();
537}
538
539LogicalResult
540LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
541 bool) {
542 return success();
543}
544
545LogicalResult
546LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
547 ArrayRef<Type> types, bool) {
548 for (Type t : types)
549 if (!isValidElementType(t))
550 return emitError() << "invalid LLVM structure element type: " << t;
551
552 return success();
553}
554
555llvm::TypeSize
556LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
557 DataLayoutEntryListRef params) const {
558 auto structSize = llvm::TypeSize::getFixed(0);
559 uint64_t structAlignment = 1;
560 for (Type element : getBody()) {
561 uint64_t elementAlignment =
562 isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
563 // Add padding to the struct size to align it to the abi alignment of the
564 // element type before than adding the size of the element.
565 structSize = llvm::alignTo(structSize, elementAlignment);
566 structSize += dataLayout.getTypeSize(element);
567
568 // The alignment requirement of a struct is equal to the strictest alignment
569 // requirement of its elements.
570 structAlignment = std::max(elementAlignment, structAlignment);
571 }
572 // At the end, add padding to the struct to satisfy its own alignment
573 // requirement. Otherwise structs inside of arrays would be misaligned.
574 structSize = llvm::alignTo(structSize, structAlignment);
575 return structSize * kBitsInByte;
576}
577
namespace {
/// Position of each value inside a struct data layout entry, a dense integer
/// elements attribute whose values are in bits. The ABI alignment comes first.
/// The preferred alignment follows and may be omitted.
enum class StructDLEntryPos {
  Abi = 0,
  Preferred = 1,
};
} // namespace
581
582static std::optional<uint64_t>
584 StructDLEntryPos pos) {
585 const auto *currentEntry =
586 llvm::find_if(params, [](DataLayoutEntryInterface entry) {
587 return entry.isTypeEntry();
588 });
589 if (currentEntry == params.end())
590 return std::nullopt;
591
592 auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
593 if (pos == StructDLEntryPos::Preferred &&
594 attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
595 // If no preferred was specified, fall back to abi alignment
596 pos = StructDLEntryPos::Abi;
597
598 return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
599}
600
601static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
603 LLVMStructType type,
604 StructDLEntryPos pos) {
605 // Packed structs always have an abi alignment of 1
606 if (pos == StructDLEntryPos::Abi && type.isPacked()) {
607 return 1;
608 }
609
610 // The alignment requirement of a struct is equal to the strictest alignment
611 // requirement of its elements.
612 uint64_t structAlignment = 1;
613 for (Type iter : type.getBody()) {
614 structAlignment =
615 std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
616 }
617
618 // Entries are only allowed to be stricter than the required alignment
619 if (std::optional<uint64_t> entryResult =
620 getStructDataLayoutEntry(params, type, pos))
621 return std::max(*entryResult / kBitsInByte, structAlignment);
622
623 return structAlignment;
624}
625
626uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
627 DataLayoutEntryListRef params) const {
628 return calculateStructAlignment(dataLayout, params, *this,
629 StructDLEntryPos::Abi);
630}
631
632uint64_t
633LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
634 DataLayoutEntryListRef params) const {
635 return calculateStructAlignment(dataLayout, params, *this,
636 StructDLEntryPos::Preferred);
637}
638
639static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
640 return llvm::cast<DenseIntElementsAttr>(attr)
641 .getValues<uint64_t>()[static_cast<size_t>(pos)];
642}
643
644bool LLVMStructType::areCompatible(
646 DataLayoutSpecInterface newSpec,
647 const DataLayoutIdentifiedEntryMap &map) const {
648 for (DataLayoutEntryInterface newEntry : newLayout) {
649 if (!newEntry.isTypeEntry())
650 continue;
651
652 const auto *previousEntry =
653 llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
654 return entry.isTypeEntry();
655 });
656 if (previousEntry == oldLayout.end())
657 continue;
658
659 uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
660 StructDLEntryPos::Abi);
661 uint64_t newAbi =
662 extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
663 if (abi < newAbi || abi % newAbi != 0)
664 return false;
665 }
666 return true;
667}
668
669LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
670 Location loc) const {
671 for (DataLayoutEntryInterface entry : entries) {
672 if (!entry.isTypeEntry())
673 continue;
674
675 auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey()));
676 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
677 if (!values || (values.size() != 2 && values.size() != 1)) {
678 return emitError(loc)
679 << "expected layout attribute for "
680 << llvm::cast<Type>(entry.getKey())
681 << " to be a dense integer elements attribute of 1 or 2 elements";
682 }
683 if (!values.getElementType().isInteger(64))
684 return emitError(loc) << "expected i64 entries for " << key;
685
686 if (key.isIdentified() || !key.getBody().empty()) {
687 return emitError(loc) << "unexpected layout attribute for struct " << key;
688 }
689
690 if (values.size() == 1)
691 continue;
692
693 if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
694 extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
695 return emitError(loc) << "preferred alignment is expected to be at least "
696 "as large as ABI alignment";
697 }
698 }
699 return mlir::success();
700}
701
702//===----------------------------------------------------------------------===//
703// LLVMTargetExtType.
704//===----------------------------------------------------------------------===//
705
706static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
707static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
708static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier";
709
710bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
711 // See llvm/lib/IR/Type.cpp for reference.
712 uint64_t properties = 0;
713
714 if (getExtTypeName().starts_with(kSpirvPrefix))
715 properties |=
716 (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
717
718 if (getExtTypeName() == kAMDGCNNamedBarrier)
719 properties |= LLVMTargetExtType::CanBeGlobal;
720
721 return (properties & prop) == prop;
722}
723
724bool LLVM::LLVMTargetExtType::supportsMemOps() const {
725 // See llvm/lib/IR/Type.cpp for reference.
726 if (getExtTypeName().starts_with(kSpirvPrefix))
727 return true;
728
729 if (getExtTypeName() == kArmSVCount)
730 return true;
731
732 return false;
733}
734
735//===----------------------------------------------------------------------===//
736// LLVMPPCFP128Type
737//===----------------------------------------------------------------------===//
738
739const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
740 return APFloat::PPCDoubleDouble();
741}
742
743//===----------------------------------------------------------------------===//
744// Utility functions.
745//===----------------------------------------------------------------------===//
746
747/// Check whether type is a compatible ptr type. These are pointer-like types
748/// with no element type, no metadata, and using the LLVM
749/// LLVMAddrSpaceAttrInterface memory space.
750static bool isCompatiblePtrType(Type type) {
751 auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type);
752 if (!ptrTy)
753 return false;
754 return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr &&
755 isa<LLVMAddrSpaceAttrInterface>(ptrTy.getMemorySpace());
756}
757
759 // clang-format off
760 if (llvm::isa<
761 BFloat16Type,
762 Float16Type,
763 Float32Type,
764 Float64Type,
765 Float80Type,
766 Float128Type,
767 LLVMArrayType,
768 LLVMByteType,
769 LLVMFunctionType,
770 LLVMLabelType,
771 LLVMMetadataType,
772 LLVMPPCFP128Type,
773 LLVMPointerType,
774 LLVMStructType,
775 LLVMTargetExtType,
776 LLVMVoidType,
777 LLVMX86AMXType,
778 TokenType
779 >(type)) {
780 // clang-format on
781 return true;
782 }
783
784 // Only signless integers are compatible.
785 if (auto intType = llvm::dyn_cast<IntegerType>(type))
786 return intType.isSignless();
787
788 // 1D vector types are compatible.
789 if (auto vecType = llvm::dyn_cast<VectorType>(type))
790 return vecType.getRank() == 1;
791
792 return isCompatiblePtrType(type);
793}
794
795static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
796 if (!compatibleTypes.insert(type).second)
797 return true;
798
799 auto isCompatible = [&](Type type) {
800 return isCompatibleImpl(type, compatibleTypes);
801 };
802
803 bool result =
805 .Case([&](LLVMStructType structType) {
806 return llvm::all_of(structType.getBody(), isCompatible);
807 })
808 .Case([&](LLVMFunctionType funcType) {
809 return isCompatible(funcType.getReturnType()) &&
810 llvm::all_of(funcType.getParams(), isCompatible);
811 })
812 .Case([](IntegerType intType) { return intType.isSignless(); })
813 .Case([&](VectorType vecType) {
814 return vecType.getRank() == 1 &&
815 isCompatible(vecType.getElementType());
816 })
817 .Case([&](LLVMPointerType pointerType) { return true; })
818 .Case([&](LLVMTargetExtType extType) {
819 return llvm::all_of(extType.getTypeParams(), isCompatible);
820 })
821 // clang-format off
822 .Case([&](LLVMArrayType containerType) {
823 return isCompatible(containerType.getElementType());
824 })
825 .Case<
826 BFloat16Type,
827 Float16Type,
828 Float32Type,
829 Float64Type,
830 Float80Type,
831 Float128Type,
832 LLVMByteType,
833 LLVMLabelType,
834 LLVMMetadataType,
835 LLVMPPCFP128Type,
836 LLVMVoidType,
837 LLVMX86AMXType,
838 TokenType
839 >([](Type) { return true; })
840 // clang-format on
841 .Case<PtrLikeTypeInterface>(
842 [](Type type) { return isCompatiblePtrType(type); })
843 .Default(false);
844
845 if (!result)
846 compatibleTypes.erase(type);
847
848 return result;
849}
850
851bool LLVMDialect::isCompatibleType(Type type) {
852 if (auto *llvmDialect =
853 type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
854 return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
855
856 DenseSet<Type> localCompatibleTypes;
857 return isCompatibleImpl(type, localCompatibleTypes);
858}
859
861 return LLVMDialect::isCompatibleType(type);
862}
863
865 return /*LLVM_PrimitiveType*/ (
867 !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) &&
868 /*LLVM_OpaqueStruct*/
869 !(isa<LLVM::LLVMStructType>(type) &&
870 cast<LLVM::LLVMStructType>(type).isOpaque()) &&
871 /*LLVM_AnyTargetExt*/
872 !(isa<LLVM::LLVMTargetExtType>(type) &&
873 !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps());
874}
875
877 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
878 Float80Type, Float128Type, LLVMPPCFP128Type>(type);
879}
880
882 if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
883 if (vecType.getRank() != 1)
884 return false;
885 Type elementType = vecType.getElementType();
886 if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
887 return intType.isSignless();
888 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
889 Float80Type, Float128Type, LLVMByteType, LLVMPointerType>(
890 elementType) ||
891 isCompatiblePtrType(elementType);
892 }
893 return false;
894}
895
896llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
897 auto vecTy = dyn_cast<VectorType>(type);
898 assert(vecTy && "incompatible with LLVM vector type");
899 if (vecTy.isScalable())
900 return llvm::ElementCount::getScalable(vecTy.getNumElements());
901 return llvm::ElementCount::getFixed(vecTy.getNumElements());
902}
903
905 assert(llvm::isa<VectorType>(vectorType) &&
906 "expected LLVM-compatible vector type");
907 return llvm::cast<VectorType>(vectorType).isScalable();
908}
909
910Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
911 bool isScalable) {
912 assert(VectorType::isValidElementType(elementType) &&
913 "incompatible element type");
914 return VectorType::get(numElements, elementType, {isScalable});
915}
916
918 const llvm::ElementCount &numElements) {
919 if (numElements.isScalable())
920 return getVectorType(elementType, numElements.getKnownMinValue(),
921 /*isScalable=*/true);
922 return getVectorType(elementType, numElements.getFixedValue(),
923 /*isScalable=*/false);
924}
925
927 assert(isCompatibleType(type) &&
928 "expected a type compatible with the LLVM dialect");
929
931 .Case<BFloat16Type, Float16Type>(
932 [](Type) { return llvm::TypeSize::getFixed(16); })
933 .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
934 .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
935 .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
936 .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
937 .Case([](IntegerType intTy) {
938 return llvm::TypeSize::getFixed(intTy.getWidth());
939 })
940 .Case([](LLVMByteType byteTy) {
941 return llvm::TypeSize::getFixed(byteTy.getBitWidth());
942 })
943 .Case<LLVMPPCFP128Type>(
944 [](Type) { return llvm::TypeSize::getFixed(128); })
945 .Case([](VectorType t) {
946 assert(isCompatibleVectorType(t) &&
947 "unexpected incompatible with LLVM vector type");
948 llvm::TypeSize elementSize =
949 getPrimitiveTypeSizeInBits(t.getElementType());
950 return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
951 elementSize.isScalable());
952 })
953 .Default([](Type ty) {
954 assert(
955 (llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, TokenType,
956 LLVMStructType, LLVMArrayType, LLVMPointerType,
957 LLVMFunctionType, LLVMTargetExtType>(ty)) &&
958 "unexpected missing support for primitive type");
959 return llvm::TypeSize::getFixed(0);
960 });
961}
962
963//===----------------------------------------------------------------------===//
964// LLVMDialect
965//===----------------------------------------------------------------------===//
966
/// Registers the LLVM dialect's ODS-generated types with the dialect. The
/// template argument list of `addTypes` is produced by the GET_TYPEDEF_LIST
/// section of the generated LLVMTypes.cpp.inc, which is expanded in place
/// below.
void LLVMDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"

      >();
}
974
975Type LLVMDialect::parseType(DialectAsmParser &parser) const {
976 return detail::parseType(parser);
977}
978
979void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
980 return detail::printType(type, os);
981}
return success()
static unsigned getBitWidth(Type type)
Definition Pattern.cpp:390
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
constexpr static const uint64_t kDefaultPointerAlignment
static void printExtTypeParams(AsmPrinter &p, ArrayRef< Type > typeParams, ArrayRef< unsigned int > intParams)
constexpr static const uint64_t kDefaultPointerSizeBits
static void printFunctionTypes(AsmPrinter &p, ArrayRef< Type > params, bool isVarArg)
Definition LLVMTypes.cpp:66
static bool isCompatibleImpl(Type type, DenseSet< Type > &compatibleTypes)
static ParseResult parseExtTypeParams(AsmParser &p, SmallVectorImpl< Type > &typeParams, SmallVectorImpl< unsigned int > &intParams)
Parses the parameter list for a target extension type.
Definition LLVMTypes.cpp:90
static constexpr llvm::StringRef kSpirvPrefix
static std::optional< uint64_t > getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, PtrDLEntryPos pos)
Returns the part of the data layout entry that corresponds to pos for the given type by interpreting ...
static bool isCompatiblePtrType(Type type)
Check whether type is a compatible ptr type.
static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value)
These are unused for now.
static ParseResult parseFunctionTypes(AsmParser &p, SmallVector< Type > &params, bool &isVarArg)
Definition LLVMTypes.cpp:36
constexpr static const uint64_t kBitsInByte
Definition LLVMTypes.cpp:30
static constexpr llvm::StringRef kArmSVCount
static constexpr llvm::StringRef kAMDGCNNamedBarrier
static LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer)
static std::optional< uint64_t > getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos)
static uint64_t calculateStructAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
b getContext())
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalEllipsis()=0
Parse a `...` token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition Attributes.h:25
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This class represents a diagnostic that is inflight and set to be reported.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
llvm::TypeSize getPrimitiveTypeSizeInBits(Type type)
Returns the size of the given primitive LLVM dialect-compatible type (including vectors) in bits,...
void printPrettyLLVMType(AsmPrinter &p, Type type)
Print any MLIR type or a concise syntax for LLVM types.
bool isLoadableType(Type type)
Returns true if the given type is a loadable type compatible with the LLVM dialect.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
ParseResult parsePrettyLLVMType(AsmParser &p, Type &type)
Parse any MLIR type or a concise syntax for LLVM types.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
PtrDLEntryPos
The positions of different values in the data layout entry for pointers.
Definition LLVMTypes.h:145
std::optional< uint64_t > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
::llvm::MapVector<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface > DataLayoutIdentifiedEntryMap
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR type into the given MLIR context if the input is valid.
llvm::ArrayRef< DataLayoutEntryInterface > DataLayoutEntryListRef
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147