MLIR 23.0.0git
LLVMTypes.cpp
Go to the documentation of this file.
1//===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the types for the LLVM dialect in MLIR. These MLIR types
10// correspond to the LLVM IR type system.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TypeDetail.h"
15
21#include "mlir/IR/TypeSupport.h"
22
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/TypeSize.h"
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::LLVM;
29
/// Factor used in this file to convert between sizes/alignments expressed in
/// bytes and the same quantities expressed in bits.
static constexpr uint64_t kBitsInByte = 8;
31
32//===----------------------------------------------------------------------===//
33// custom<FunctionTypes>
34//===----------------------------------------------------------------------===//
35
36static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37 bool &isVarArg) {
38 isVarArg = false;
39 // `(` `)`
40 if (succeeded(p.parseOptionalRParen()))
41 return success();
42
43 // `(` `...` `)`
44 if (succeeded(p.parseOptionalEllipsis())) {
45 isVarArg = true;
46 return p.parseRParen();
47 }
48
49 // type (`,` type)* (`,` `...`)?
50 Type type;
51 if (parsePrettyLLVMType(p, type))
52 return failure();
53 params.push_back(type);
54 while (succeeded(p.parseOptionalComma())) {
55 if (succeeded(p.parseOptionalEllipsis())) {
56 isVarArg = true;
57 return p.parseRParen();
58 }
59 if (parsePrettyLLVMType(p, type))
60 return failure();
61 params.push_back(type);
62 }
63 return p.parseRParen();
64}
65
67 bool isVarArg) {
68 llvm::interleaveComma(params, p,
69 [&](Type type) { printPrettyLLVMType(p, type); });
70 if (isVarArg) {
71 if (!params.empty())
72 p << ", ";
73 p << "...";
74 }
75 p << ')';
76}
77
78//===----------------------------------------------------------------------===//
79// custom<ExtTypeParams>
80//===----------------------------------------------------------------------===//
81
82/// Parses the parameter list for a target extension type. The parameter list
83/// contains an optional list of type parameters, followed by an optional list
84/// of integer parameters. Type and integer parameters cannot be interleaved in
85/// the list.
86/// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87/// typeList ::= type ("," type)*
88/// intList ::= integer ("," integer)*
89static ParseResult
92 bool parseType = true;
93 auto typeOrIntParser = [&]() -> ParseResult {
94 unsigned int i;
95 auto intResult = p.parseOptionalInteger(i);
96 if (intResult.has_value() && !failed(*intResult)) {
97 // Successfully parsed an integer.
98 intParams.push_back(i);
99 // After the first integer was successfully parsed, no
100 // more types can be parsed.
101 parseType = false;
102 return success();
103 }
104 if (parseType) {
105 Type t;
106 if (!parsePrettyLLVMType(p, t)) {
107 // Successfully parsed a type.
108 typeParams.push_back(t);
109 return success();
110 }
111 }
112 return failure();
113 };
114 if (p.parseCommaSeparatedList(typeOrIntParser)) {
116 "failed to parse parameter list for target extension type");
117 return failure();
118 }
119 return success();
120}
121
123 ArrayRef<unsigned int> intParams) {
124 p << typeParams;
125 if (!typeParams.empty() && !intParams.empty())
126 p << ", ";
127
128 p << intParams;
129}
130
131//===----------------------------------------------------------------------===//
132// ODS-Generated Definitions
133//===----------------------------------------------------------------------===//
134
135/// These are unused for now.
136/// TODO: Move over to these once more types have been migrated to TypeDef.
137[[maybe_unused]] static OptionalParseResult
138generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
139[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
140 AsmPrinter &printer);
141
142#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143
144#define GET_TYPEDEF_CLASSES
145#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146
147//===----------------------------------------------------------------------===//
148// LLVMArrayType
149//===----------------------------------------------------------------------===//
150
151bool LLVMArrayType::isValidElementType(Type type) {
152 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153 LLVMFunctionType, LLVMTokenType>(type);
154}
155
156LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
157 assert(elementType && "expected non-null subtype");
158 return Base::get(elementType.getContext(), elementType, numElements);
159}
160
161LLVMArrayType
162LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
163 Type elementType, uint64_t numElements) {
164 assert(elementType && "expected non-null subtype");
165 return Base::getChecked(emitError, elementType.getContext(), elementType,
166 numElements);
167}
168
169LogicalResult
170LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
171 Type elementType, uint64_t numElements) {
172 if (!isValidElementType(elementType))
173 return emitError() << "invalid array element type: " << elementType;
174 return success();
175}
176
177//===----------------------------------------------------------------------===//
178// DataLayoutTypeInterface
179//===----------------------------------------------------------------------===//
180
181llvm::TypeSize
182LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183 DataLayoutEntryListRef params) const {
184 return llvm::TypeSize::getFixed(kBitsInByte *
185 getTypeSize(dataLayout, params));
186}
187
188llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189 DataLayoutEntryListRef params) const {
190 return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191 dataLayout.getTypeABIAlignment(getElementType())) *
193}
194
195uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196 DataLayoutEntryListRef params) const {
197 return dataLayout.getTypeABIAlignment(getElementType());
198}
199
200uint64_t
201LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202 DataLayoutEntryListRef params) const {
203 return dataLayout.getTypePreferredAlignment(getElementType());
204}
205
206//===----------------------------------------------------------------------===//
207// Function type.
208//===----------------------------------------------------------------------===//
209
210bool LLVMFunctionType::isValidArgumentType(Type type) {
211 if (auto structType = dyn_cast<LLVMStructType>(type))
212 return !structType.isOpaque();
213
214 return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
215}
216
217bool LLVMFunctionType::isValidResultType(Type type) {
218 return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
219}
220
221LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
222 bool isVarArg) {
223 assert(result && "expected non-null result");
224 return Base::get(result.getContext(), result, arguments, isVarArg);
225}
226
227LLVMFunctionType
228LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
229 Type result, ArrayRef<Type> arguments,
230 bool isVarArg) {
231 assert(result && "expected non-null result");
232 return Base::getChecked(emitError, result.getContext(), result, arguments,
233 isVarArg);
234}
235
236LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
237 TypeRange results) const {
238 // LLVM functions have exactly one return type. An empty results range
239 // corresponds to a void return type (as FunctionOpInterface represents void
240 // functions with 0 results). More than one result is not valid.
241 if (results.size() > 1)
242 return {};
243 Type resultType =
244 results.empty() ? LLVMVoidType::get(getContext()) : results[0];
245 if (!isValidResultType(resultType))
246 return {};
247 if (!llvm::all_of(inputs, isValidArgumentType))
248 return {};
249 return get(resultType, llvm::to_vector(inputs), isVarArg());
250}
251
252ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
253 return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
254}
255
256LogicalResult
257LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
258 Type result, ArrayRef<Type> arguments, bool) {
259 if (!isValidResultType(result))
260 return emitError() << "invalid function result type: " << result;
261
262 for (Type arg : arguments)
263 if (!isValidArgumentType(arg))
264 return emitError() << "invalid function argument type: " << arg;
265
266 return success();
267}
268
269//===----------------------------------------------------------------------===//
270// DataLayoutTypeInterface
271//===----------------------------------------------------------------------===//
272
/// Pointer size, in bits, used when the data layout provides no pointer entry.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
/// Pointer alignment used when the data layout provides no pointer entry.
static constexpr uint64_t kDefaultPointerAlignment = 8;
275
277 PtrDLEntryPos pos) {
278 auto spec = cast<DenseIntElementsAttr>(attr);
279 auto idx = static_cast<int64_t>(pos);
280 if (idx >= spec.size())
281 return std::nullopt;
282 return spec.getValues<uint64_t>()[idx];
283}
284
285/// Returns the part of the data layout entry that corresponds to `pos` for the
286/// given `type` by interpreting the list of entries `params`. For the pointer
287/// type in the default address space, returns the default value if the entries
288/// do not provide a custom one, for other address spaces returns std::nullopt.
289static std::optional<uint64_t>
291 PtrDLEntryPos pos) {
292 // First, look for the entry for the pointer in the current address space.
293 Attribute currentEntry;
294 for (DataLayoutEntryInterface entry : params) {
295 if (!entry.isTypeEntry())
296 continue;
297 if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() ==
298 type.getAddressSpace()) {
299 currentEntry = entry.getValue();
300 break;
301 }
302 }
303 if (currentEntry) {
304 std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
305 // If the optional `PtrDLEntryPos::Index` entry is not available, use the
306 // pointer size as the index bitwidth.
307 if (!value && pos == PtrDLEntryPos::Index)
308 value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
309 bool isSizeOrIndex =
311 return *value / (isSizeOrIndex ? 1 : kBitsInByte);
312 }
313
314 // If not found, and this is the pointer to the default memory space, assume
315 // 64-bit pointers.
316 if (type.getAddressSpace() == 0) {
317 bool isSizeOrIndex =
319 return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
320 }
321
322 return std::nullopt;
323}
324
325llvm::TypeSize
326LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
327 DataLayoutEntryListRef params) const {
328 if (std::optional<uint64_t> size =
330 return llvm::TypeSize::getFixed(*size);
331
332 // For other memory spaces, use the size of the pointer to the default memory
333 // space.
334 return dataLayout.getTypeSizeInBits(get(getContext()));
335}
336
337uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
338 DataLayoutEntryListRef params) const {
339 if (std::optional<uint64_t> alignment =
341 return *alignment;
342
343 return dataLayout.getTypeABIAlignment(get(getContext()));
344}
345
346uint64_t
347LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
348 DataLayoutEntryListRef params) const {
349 if (std::optional<uint64_t> alignment =
351 return *alignment;
352
353 return dataLayout.getTypePreferredAlignment(get(getContext()));
354}
355
356std::optional<uint64_t>
357LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
358 DataLayoutEntryListRef params) const {
359 if (std::optional<uint64_t> indexBitwidth =
361 return *indexBitwidth;
362
363 return dataLayout.getTypeIndexBitwidth(get(getContext()));
364}
365
366bool LLVMPointerType::areCompatible(
368 DataLayoutSpecInterface newSpec,
369 const DataLayoutIdentifiedEntryMap &map) const {
370 for (DataLayoutEntryInterface newEntry : newLayout) {
371 if (!newEntry.isTypeEntry())
372 continue;
373 uint64_t size = kDefaultPointerSizeBits;
374 uint64_t abi = kDefaultPointerAlignment;
375 auto newType =
376 llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey()));
377 const auto *it =
378 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
379 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
380 return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
381 newType.getAddressSpace();
382 }
383 return false;
384 });
385 if (it == oldLayout.end()) {
386 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
387 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
388 return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
389 }
390 return false;
391 });
392 }
393 if (it != oldLayout.end()) {
396 }
397
398 Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
399 uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
400 uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
401 if (size != newSize || abi < newAbi || abi % newAbi != 0)
402 return false;
403 }
404 return true;
405}
406
407LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
408 Location loc) const {
409 for (DataLayoutEntryInterface entry : entries) {
410 if (!entry.isTypeEntry())
411 continue;
412 auto key = llvm::cast<Type>(entry.getKey());
413 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
414 if (!values || (values.size() != 3 && values.size() != 4)) {
415 return emitError(loc)
416 << "expected layout attribute for " << key
417 << " to be a dense integer elements attribute with 3 or 4 "
418 "elements";
419 }
420 if (!values.getElementType().isInteger(64))
421 return emitError(loc) << "expected i64 parameters for " << key;
422
425 return emitError(loc) << "preferred alignment is expected to be at least "
426 "as large as ABI alignment";
427 }
428 }
429 return success();
430}
431
432//===----------------------------------------------------------------------===//
433// Struct type.
434//===----------------------------------------------------------------------===//
435
436bool LLVMStructType::isValidElementType(Type type) {
437 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
438 LLVMFunctionType, LLVMTokenType>(type);
439}
440
441LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
442 StringRef name) {
443 return Base::get(context, name, /*opaque=*/false);
444}
445
446LLVMStructType LLVMStructType::getIdentifiedChecked(
447 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
448 StringRef name) {
449 return Base::getChecked(emitError, context, name, /*opaque=*/false);
450}
451
452LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
453 StringRef name,
454 ArrayRef<Type> elements,
455 bool isPacked) {
456 std::string stringName = name.str();
457 unsigned counter = 0;
458 do {
459 auto type = LLVMStructType::getIdentified(context, stringName);
460 if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
461 counter += 1;
462 stringName = (Twine(name) + "." + std::to_string(counter)).str();
463 continue;
464 }
465 return type;
466 } while (true);
467}
468
469LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
470 ArrayRef<Type> types, bool isPacked) {
471 return Base::get(context, types, isPacked);
472}
473
474LLVMStructType
475LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
476 MLIRContext *context, ArrayRef<Type> types,
477 bool isPacked) {
478 return Base::getChecked(emitError, context, types, isPacked);
479}
480
481LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
482 return Base::get(context, name, /*opaque=*/true);
483}
484
485LLVMStructType
486LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
487 MLIRContext *context, StringRef name) {
488 return Base::getChecked(emitError, context, name, /*opaque=*/true);
489}
490
491LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
492 assert(isIdentified() && "can only set bodies of identified structs");
493 assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
494 "expected valid body types");
495 return Base::mutate(types, isPacked);
496}
497
498bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
499bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
500bool LLVMStructType::isOpaque() const {
501 return getImpl()->isIdentified() &&
502 (getImpl()->isOpaque() || !getImpl()->isInitialized());
503}
504bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
505StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
506ArrayRef<Type> LLVMStructType::getBody() const {
507 return isIdentified() ? getImpl()->getIdentifiedStructBody()
508 : getImpl()->getTypeList();
509}
510
511LogicalResult
512LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
513 bool) {
514 return success();
515}
516
517LogicalResult
518LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
519 ArrayRef<Type> types, bool) {
520 for (Type t : types)
521 if (!isValidElementType(t))
522 return emitError() << "invalid LLVM structure element type: " << t;
523
524 return success();
525}
526
527llvm::TypeSize
528LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
529 DataLayoutEntryListRef params) const {
530 auto structSize = llvm::TypeSize::getFixed(0);
531 uint64_t structAlignment = 1;
532 for (Type element : getBody()) {
533 uint64_t elementAlignment =
534 isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
535 // Add padding to the struct size to align it to the abi alignment of the
536 // element type before than adding the size of the element.
537 structSize = llvm::alignTo(structSize, elementAlignment);
538 structSize += dataLayout.getTypeSize(element);
539
540 // The alignment requirement of a struct is equal to the strictest alignment
541 // requirement of its elements.
542 structAlignment = std::max(elementAlignment, structAlignment);
543 }
544 // At the end, add padding to the struct to satisfy its own alignment
545 // requirement. Otherwise structs inside of arrays would be misaligned.
546 structSize = llvm::alignTo(structSize, structAlignment);
547 return structSize * kBitsInByte;
548}
549
namespace {
/// Position of each value inside the dense integer attribute of a struct data
/// layout entry.
enum class StructDLEntryPos {
  /// Index of the ABI alignment.
  Abi = 0,
  /// Index of the preferred alignment.
  Preferred = 1
};
} // namespace
553
554static std::optional<uint64_t>
556 StructDLEntryPos pos) {
557 const auto *currentEntry =
558 llvm::find_if(params, [](DataLayoutEntryInterface entry) {
559 return entry.isTypeEntry();
560 });
561 if (currentEntry == params.end())
562 return std::nullopt;
563
564 auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
565 if (pos == StructDLEntryPos::Preferred &&
566 attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
567 // If no preferred was specified, fall back to abi alignment
568 pos = StructDLEntryPos::Abi;
569
570 return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
571}
572
573static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
575 LLVMStructType type,
576 StructDLEntryPos pos) {
577 // Packed structs always have an abi alignment of 1
578 if (pos == StructDLEntryPos::Abi && type.isPacked()) {
579 return 1;
580 }
581
582 // The alignment requirement of a struct is equal to the strictest alignment
583 // requirement of its elements.
584 uint64_t structAlignment = 1;
585 for (Type iter : type.getBody()) {
586 structAlignment =
587 std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
588 }
589
590 // Entries are only allowed to be stricter than the required alignment
591 if (std::optional<uint64_t> entryResult =
592 getStructDataLayoutEntry(params, type, pos))
593 return std::max(*entryResult / kBitsInByte, structAlignment);
594
595 return structAlignment;
596}
597
598uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
599 DataLayoutEntryListRef params) const {
600 return calculateStructAlignment(dataLayout, params, *this,
601 StructDLEntryPos::Abi);
602}
603
604uint64_t
605LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
606 DataLayoutEntryListRef params) const {
607 return calculateStructAlignment(dataLayout, params, *this,
608 StructDLEntryPos::Preferred);
609}
610
611static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
612 return llvm::cast<DenseIntElementsAttr>(attr)
613 .getValues<uint64_t>()[static_cast<size_t>(pos)];
614}
615
616bool LLVMStructType::areCompatible(
618 DataLayoutSpecInterface newSpec,
619 const DataLayoutIdentifiedEntryMap &map) const {
620 for (DataLayoutEntryInterface newEntry : newLayout) {
621 if (!newEntry.isTypeEntry())
622 continue;
623
624 const auto *previousEntry =
625 llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
626 return entry.isTypeEntry();
627 });
628 if (previousEntry == oldLayout.end())
629 continue;
630
631 uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
632 StructDLEntryPos::Abi);
633 uint64_t newAbi =
634 extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
635 if (abi < newAbi || abi % newAbi != 0)
636 return false;
637 }
638 return true;
639}
640
641LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
642 Location loc) const {
643 for (DataLayoutEntryInterface entry : entries) {
644 if (!entry.isTypeEntry())
645 continue;
646
647 auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey()));
648 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
649 if (!values || (values.size() != 2 && values.size() != 1)) {
650 return emitError(loc)
651 << "expected layout attribute for "
652 << llvm::cast<Type>(entry.getKey())
653 << " to be a dense integer elements attribute of 1 or 2 elements";
654 }
655 if (!values.getElementType().isInteger(64))
656 return emitError(loc) << "expected i64 entries for " << key;
657
658 if (key.isIdentified() || !key.getBody().empty()) {
659 return emitError(loc) << "unexpected layout attribute for struct " << key;
660 }
661
662 if (values.size() == 1)
663 continue;
664
665 if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
666 extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
667 return emitError(loc) << "preferred alignment is expected to be at least "
668 "as large as ABI alignment";
669 }
670 }
671 return mlir::success();
672}
673
674//===----------------------------------------------------------------------===//
675// LLVMTargetExtType.
676//===----------------------------------------------------------------------===//
677
678static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
679static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
680static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier";
681
682bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
683 // See llvm/lib/IR/Type.cpp for reference.
684 uint64_t properties = 0;
685
686 if (getExtTypeName().starts_with(kSpirvPrefix))
687 properties |=
688 (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
689
690 if (getExtTypeName() == kAMDGCNNamedBarrier)
691 properties |= LLVMTargetExtType::CanBeGlobal;
692
693 return (properties & prop) == prop;
694}
695
696bool LLVM::LLVMTargetExtType::supportsMemOps() const {
697 // See llvm/lib/IR/Type.cpp for reference.
698 if (getExtTypeName().starts_with(kSpirvPrefix))
699 return true;
700
701 if (getExtTypeName() == kArmSVCount)
702 return true;
703
704 return false;
705}
706
707//===----------------------------------------------------------------------===//
708// LLVMPPCFP128Type
709//===----------------------------------------------------------------------===//
710
711const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
712 return APFloat::PPCDoubleDouble();
713}
714
715//===----------------------------------------------------------------------===//
716// Utility functions.
717//===----------------------------------------------------------------------===//
718
719/// Check whether type is a compatible ptr type. These are pointer-like types
720/// with no element type, no metadata, and using the LLVM
721/// LLVMAddrSpaceAttrInterface memory space.
722static bool isCompatiblePtrType(Type type) {
723 auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type);
724 if (!ptrTy)
725 return false;
726 return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr &&
727 isa<LLVMAddrSpaceAttrInterface>(ptrTy.getMemorySpace());
728}
729
731 // clang-format off
732 if (llvm::isa<
733 BFloat16Type,
734 Float16Type,
735 Float32Type,
736 Float64Type,
737 Float80Type,
738 Float128Type,
739 LLVMArrayType,
740 LLVMFunctionType,
741 LLVMLabelType,
742 LLVMMetadataType,
743 LLVMPPCFP128Type,
744 LLVMPointerType,
745 LLVMStructType,
746 LLVMTokenType,
747 LLVMTargetExtType,
748 LLVMVoidType,
749 LLVMX86AMXType
750 >(type)) {
751 // clang-format on
752 return true;
753 }
754
755 // Only signless integers are compatible.
756 if (auto intType = llvm::dyn_cast<IntegerType>(type))
757 return intType.isSignless();
758
759 // 1D vector types are compatible.
760 if (auto vecType = llvm::dyn_cast<VectorType>(type))
761 return vecType.getRank() == 1;
762
763 return isCompatiblePtrType(type);
764}
765
766static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
767 if (!compatibleTypes.insert(type).second)
768 return true;
769
770 auto isCompatible = [&](Type type) {
771 return isCompatibleImpl(type, compatibleTypes);
772 };
773
774 bool result =
776 .Case([&](LLVMStructType structType) {
777 return llvm::all_of(structType.getBody(), isCompatible);
778 })
779 .Case([&](LLVMFunctionType funcType) {
780 return isCompatible(funcType.getReturnType()) &&
781 llvm::all_of(funcType.getParams(), isCompatible);
782 })
783 .Case([](IntegerType intType) { return intType.isSignless(); })
784 .Case([&](VectorType vecType) {
785 return vecType.getRank() == 1 &&
786 isCompatible(vecType.getElementType());
787 })
788 .Case([&](LLVMPointerType pointerType) { return true; })
789 .Case([&](LLVMTargetExtType extType) {
790 return llvm::all_of(extType.getTypeParams(), isCompatible);
791 })
792 // clang-format off
793 .Case([&](LLVMArrayType containerType) {
794 return isCompatible(containerType.getElementType());
795 })
796 .Case<
797 BFloat16Type,
798 Float16Type,
799 Float32Type,
800 Float64Type,
801 Float80Type,
802 Float128Type,
803 LLVMLabelType,
804 LLVMMetadataType,
805 LLVMPPCFP128Type,
806 LLVMTokenType,
807 LLVMVoidType,
808 LLVMX86AMXType
809 >([](Type) { return true; })
810 // clang-format on
811 .Case<PtrLikeTypeInterface>(
812 [](Type type) { return isCompatiblePtrType(type); })
813 .Default(false);
814
815 if (!result)
816 compatibleTypes.erase(type);
817
818 return result;
819}
820
821bool LLVMDialect::isCompatibleType(Type type) {
822 if (auto *llvmDialect =
823 type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
824 return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
825
826 DenseSet<Type> localCompatibleTypes;
827 return isCompatibleImpl(type, localCompatibleTypes);
828}
829
831 return LLVMDialect::isCompatibleType(type);
832}
833
835 return /*LLVM_PrimitiveType*/ (
837 !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) &&
838 /*LLVM_OpaqueStruct*/
839 !(isa<LLVM::LLVMStructType>(type) &&
840 cast<LLVM::LLVMStructType>(type).isOpaque()) &&
841 /*LLVM_AnyTargetExt*/
842 !(isa<LLVM::LLVMTargetExtType>(type) &&
843 !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps());
844}
845
847 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
848 Float80Type, Float128Type, LLVMPPCFP128Type>(type);
849}
850
852 if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
853 if (vecType.getRank() != 1)
854 return false;
855 Type elementType = vecType.getElementType();
856 if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
857 return intType.isSignless();
858 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
859 Float80Type, Float128Type, LLVMPointerType>(elementType) ||
860 isCompatiblePtrType(elementType);
861 }
862 return false;
863}
864
865llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
866 auto vecTy = dyn_cast<VectorType>(type);
867 assert(vecTy && "incompatible with LLVM vector type");
868 if (vecTy.isScalable())
869 return llvm::ElementCount::getScalable(vecTy.getNumElements());
870 return llvm::ElementCount::getFixed(vecTy.getNumElements());
871}
872
874 assert(llvm::isa<VectorType>(vectorType) &&
875 "expected LLVM-compatible vector type");
876 return llvm::cast<VectorType>(vectorType).isScalable();
877}
878
879Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
880 bool isScalable) {
881 assert(VectorType::isValidElementType(elementType) &&
882 "incompatible element type");
883 return VectorType::get(numElements, elementType, {isScalable});
884}
885
887 const llvm::ElementCount &numElements) {
888 if (numElements.isScalable())
889 return getVectorType(elementType, numElements.getKnownMinValue(),
890 /*isScalable=*/true);
891 return getVectorType(elementType, numElements.getFixedValue(),
892 /*isScalable=*/false);
893}
894
896 assert(isCompatibleType(type) &&
897 "expected a type compatible with the LLVM dialect");
898
900 .Case<BFloat16Type, Float16Type>(
901 [](Type) { return llvm::TypeSize::getFixed(16); })
902 .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
903 .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
904 .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
905 .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
906 .Case([](IntegerType intTy) {
907 return llvm::TypeSize::getFixed(intTy.getWidth());
908 })
909 .Case<LLVMPPCFP128Type>(
910 [](Type) { return llvm::TypeSize::getFixed(128); })
911 .Case([](VectorType t) {
912 assert(isCompatibleVectorType(t) &&
913 "unexpected incompatible with LLVM vector type");
914 llvm::TypeSize elementSize =
915 getPrimitiveTypeSizeInBits(t.getElementType());
916 return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
917 elementSize.isScalable());
918 })
919 .Default([](Type ty) {
920 assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
921 LLVMTokenType, LLVMStructType, LLVMArrayType,
922 LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
923 ty)) &&
924 "unexpected missing support for primitive type");
925 return llvm::TypeSize::getFixed(0);
926 });
927}
928
929//===----------------------------------------------------------------------===//
930// LLVMDialect
931//===----------------------------------------------------------------------===//
932
933void LLVMDialect::registerTypes() {
934 addTypes<
935#define GET_TYPEDEF_LIST
936#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
937 >();
938}
939
940Type LLVMDialect::parseType(DialectAsmParser &parser) const {
941 return detail::parseType(parser);
942}
943
944void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
945 return detail::printType(type, os);
946}
return success()
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
constexpr static const uint64_t kDefaultPointerAlignment
static void printExtTypeParams(AsmPrinter &p, ArrayRef< Type > typeParams, ArrayRef< unsigned int > intParams)
constexpr static const uint64_t kDefaultPointerSizeBits
static void printFunctionTypes(AsmPrinter &p, ArrayRef< Type > params, bool isVarArg)
Definition LLVMTypes.cpp:66
static bool isCompatibleImpl(Type type, DenseSet< Type > &compatibleTypes)
static ParseResult parseExtTypeParams(AsmParser &p, SmallVectorImpl< Type > &typeParams, SmallVectorImpl< unsigned int > &intParams)
Parses the parameter list for a target extension type.
Definition LLVMTypes.cpp:90
static constexpr llvm::StringRef kSpirvPrefix
static std::optional< uint64_t > getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, PtrDLEntryPos pos)
Returns the part of the data layout entry that corresponds to pos for the given type by interpreting ...
static bool isCompatiblePtrType(Type type)
Check whether type is a compatible ptr type.
static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value)
These are unused for now.
static ParseResult parseFunctionTypes(AsmParser &p, SmallVector< Type > &params, bool &isVarArg)
Definition LLVMTypes.cpp:36
constexpr static const uint64_t kBitsInByte
Definition LLVMTypes.cpp:30
static constexpr llvm::StringRef kArmSVCount
static constexpr llvm::StringRef kAMDGCNNamedBarrier
static LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer)
static std::optional< uint64_t > getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos)
static uint64_t calculateStructAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
b getContext())
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalEllipsis()=0
Parse a `...` (ellipsis) token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition Attributes.h:25
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This class represents a diagnostic that is inflight and set to be reported.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
llvm::TypeSize getPrimitiveTypeSizeInBits(Type type)
Returns the size of the given primitive LLVM dialect-compatible type (including vectors) in bits,...
void printPrettyLLVMType(AsmPrinter &p, Type type)
Print any MLIR type or a concise syntax for LLVM types.
bool isLoadableType(Type type)
Returns true if the given type is a loadable type compatible with the LLVM dialect.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
ParseResult parsePrettyLLVMType(AsmParser &p, Type &type)
Parse any MLIR type or a concise syntax for LLVM types.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
PtrDLEntryPos
The positions of different values in the data layout entry for pointers.
Definition LLVMTypes.h:146
std::optional< uint64_t > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
::llvm::MapVector<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface > DataLayoutIdentifiedEntryMap
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::ArrayRef< DataLayoutEntryInterface > DataLayoutEntryListRef
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144