MLIR 22.0.0git
LLVMTypes.cpp
Go to the documentation of this file.
1//===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the types for the LLVM dialect in MLIR. These MLIR types
10// correspond to the LLVM IR type system.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TypeDetail.h"
15
21#include "mlir/IR/TypeSupport.h"
22
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/TypeSize.h"
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::LLVM;
29
// Conversion factor between bit and byte quantities used by the layout
// queries in this file.
static constexpr uint64_t kBitsInByte = 8;
31
32//===----------------------------------------------------------------------===//
33// custom<FunctionTypes>
34//===----------------------------------------------------------------------===//
35
36static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37 bool &isVarArg) {
38 isVarArg = false;
39 // `(` `)`
40 if (succeeded(p.parseOptionalRParen()))
41 return success();
42
43 // `(` `...` `)`
44 if (succeeded(p.parseOptionalEllipsis())) {
45 isVarArg = true;
46 return p.parseRParen();
47 }
48
49 // type (`,` type)* (`,` `...`)?
50 Type type;
51 if (parsePrettyLLVMType(p, type))
52 return failure();
53 params.push_back(type);
54 while (succeeded(p.parseOptionalComma())) {
55 if (succeeded(p.parseOptionalEllipsis())) {
56 isVarArg = true;
57 return p.parseRParen();
58 }
59 if (parsePrettyLLVMType(p, type))
60 return failure();
61 params.push_back(type);
62 }
63 return p.parseRParen();
64}
65
67 bool isVarArg) {
68 llvm::interleaveComma(params, p,
69 [&](Type type) { printPrettyLLVMType(p, type); });
70 if (isVarArg) {
71 if (!params.empty())
72 p << ", ";
73 p << "...";
74 }
75 p << ')';
76}
77
78//===----------------------------------------------------------------------===//
79// custom<ExtTypeParams>
80//===----------------------------------------------------------------------===//
81
82/// Parses the parameter list for a target extension type. The parameter list
83/// contains an optional list of type parameters, followed by an optional list
84/// of integer parameters. Type and integer parameters cannot be interleaved in
85/// the list.
86/// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87/// typeList ::= type ("," type)*
88/// intList ::= integer ("," integer)*
// NOTE(review): original lines 90-91 (the rest of the signature) and 115 (the
// start of the diagnostic call) are missing from this copy. Per the symbol
// summary at the end of this page the signature is
//   parseExtTypeParams(AsmParser &p, SmallVectorImpl<Type> &typeParams,
//                      SmallVectorImpl<unsigned int> &intParams)
// The comma-separated elements are parsed into `typeParams` and `intParams`.
89static ParseResult
92 bool parseType = true;
// Each element is tried as an integer first, then as a type. Once an integer
// has been accepted, `parseType` is cleared so a later type is rejected.
93 auto typeOrIntParser = [&]() -> ParseResult {
94 unsigned int i;
95 auto intResult = p.parseOptionalInteger(i);
// The optional result is empty when no integer was present; otherwise it
// holds the success/failure of parsing that integer.
96 if (intResult.has_value() && !failed(*intResult)) {
97 // Successfully parsed an integer.
98 intParams.push_back(i);
99 // After the first integer was successfully parsed, no
100 // more types can be parsed.
101 parseType = false;
102 return success();
103 }
// NOTE(review): if an integer token was present but failed to parse (e.g. out
// of range for `unsigned int`), control still reaches the type parser below
// while `parseType` is true. That may produce a second, confusing diagnostic.
// Consider returning the failure directly.
104 if (parseType) {
105 Type t;
106 if (!parsePrettyLLVMType(p, t)) {
107 // Successfully parsed a type.
108 typeParams.push_back(t);
109 return success();
110 }
111 }
112 return failure();
113 };
// On any element failure, a summary error is emitted (the opening of that
// call is on the missing line 115) and failure is propagated.
114 if (p.parseCommaSeparatedList(typeOrIntParser)) {
116 "failed to parse parameter list for target extension type");
117 return failure();
118 }
119 return success();
120}
121
123 ArrayRef<unsigned int> intParams) {
124 p << typeParams;
125 if (!typeParams.empty() && !intParams.empty())
126 p << ", ";
127
128 p << intParams;
129}
130
131//===----------------------------------------------------------------------===//
132// ODS-Generated Definitions
133//===----------------------------------------------------------------------===//
134
135/// These are unused for now.
136/// TODO: Move over to these once more types have been migrated to TypeDef.
137[[maybe_unused]] static OptionalParseResult
138generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
139[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
140 AsmPrinter &printer);
141
142#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143
144#define GET_TYPEDEF_CLASSES
145#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146
147//===----------------------------------------------------------------------===//
148// LLVMArrayType
149//===----------------------------------------------------------------------===//
150
151bool LLVMArrayType::isValidElementType(Type type) {
152 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153 LLVMFunctionType, LLVMTokenType>(type);
154}
155
156LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
157 assert(elementType && "expected non-null subtype");
158 return Base::get(elementType.getContext(), elementType, numElements);
159}
160
161LLVMArrayType
162LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
163 Type elementType, uint64_t numElements) {
164 assert(elementType && "expected non-null subtype");
165 return Base::getChecked(emitError, elementType.getContext(), elementType,
166 numElements);
167}
168
169LogicalResult
170LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
171 Type elementType, uint64_t numElements) {
172 if (!isValidElementType(elementType))
173 return emitError() << "invalid array element type: " << elementType;
174 return success();
175}
176
177//===----------------------------------------------------------------------===//
178// DataLayoutTypeInterface
179//===----------------------------------------------------------------------===//
180
181llvm::TypeSize
182LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183 DataLayoutEntryListRef params) const {
184 return llvm::TypeSize::getFixed(kBitsInByte *
185 getTypeSize(dataLayout, params));
186}
187
// Returns the array size in bytes: the element size rounded up to the
// element's ABI alignment (so consecutive elements stay aligned), multiplied
// by a factor on original line 192. That line is missing from this copy;
// presumably it is the element count.
188llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189 DataLayoutEntryListRef params) const {
190 return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191 dataLayout.getTypeABIAlignment(getElementType())) *
193}
194
195uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196 DataLayoutEntryListRef params) const {
197 return dataLayout.getTypeABIAlignment(getElementType());
198}
199
200uint64_t
201LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202 DataLayoutEntryListRef params) const {
203 return dataLayout.getTypePreferredAlignment(getElementType());
204}
205
206//===----------------------------------------------------------------------===//
207// Function type.
208//===----------------------------------------------------------------------===//
209
210bool LLVMFunctionType::isValidArgumentType(Type type) {
211 return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
212}
213
214bool LLVMFunctionType::isValidResultType(Type type) {
215 return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
216}
217
218LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
219 bool isVarArg) {
220 assert(result && "expected non-null result");
221 return Base::get(result.getContext(), result, arguments, isVarArg);
222}
223
224LLVMFunctionType
225LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226 Type result, ArrayRef<Type> arguments,
227 bool isVarArg) {
228 assert(result && "expected non-null result");
229 return Base::getChecked(emitError, result.getContext(), result, arguments,
230 isVarArg);
231}
232
233LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234 TypeRange results) const {
235 if (results.size() != 1 || !isValidResultType(results[0]))
236 return {};
237 if (!llvm::all_of(inputs, isValidArgumentType))
238 return {};
239 return get(results[0], llvm::to_vector(inputs), isVarArg());
240}
241
242ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
243 return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
244}
245
246LogicalResult
247LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
248 Type result, ArrayRef<Type> arguments, bool) {
249 if (!isValidResultType(result))
250 return emitError() << "invalid function result type: " << result;
251
252 for (Type arg : arguments)
253 if (!isValidArgumentType(arg))
254 return emitError() << "invalid function argument type: " << arg;
255
256 return success();
257}
258
259//===----------------------------------------------------------------------===//
260// DataLayoutTypeInterface
261//===----------------------------------------------------------------------===//
262
// Pointer layout assumed for address space 0 when no data layout entry is
// provided: the size is in bits and the alignment is in bytes.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
static constexpr uint64_t kDefaultPointerAlignment = 8;
265
// NOTE(review): the first line of this signature (original line 266) is
// missing from this copy. From the call sites it presumably reads
//   static std::optional<uint64_t> extractPointerSpecValue(Attribute attr,
// Returns the value at position `pos` of the pointer layout spec `attr`, which
// must be a DenseIntElementsAttr (enforced by `cast`). Returns std::nullopt
// when the spec is too short to contain that position.
267 PtrDLEntryPos pos) {
268 auto spec = cast<DenseIntElementsAttr>(attr);
269 auto idx = static_cast<int64_t>(pos);
270 if (idx >= spec.size())
271 return std::nullopt;
272 return spec.getValues<uint64_t>()[idx];
273}
274
275/// Returns the part of the data layout entry that corresponds to `pos` for the
276/// given `type` by interpreting the list of entries `params`. For the pointer
277/// type in the default address space, returns the default value if the entries
278/// do not provide a custom one, for other address spaces returns std::nullopt.
// NOTE(review): original line 280 (the middle of the signature) and lines 300
// and 308 (the initializers of `isSizeOrIndex`) are missing from this copy.
// Per the symbol summary at the end of this page the signature is
//   getPointerDataLayoutEntry(DataLayoutEntryListRef params,
//                             LLVMPointerType type, PtrDLEntryPos pos)
// `isSizeOrIndex` is presumably true when `pos` is the Size or Index position.
279static std::optional<uint64_t>
281 PtrDLEntryPos pos) {
282 // First, look for the entry for the pointer in the current address space.
283 Attribute currentEntry;
284 for (DataLayoutEntryInterface entry : params) {
285 if (!entry.isTypeEntry())
286 continue;
287 if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() ==
288 type.getAddressSpace()) {
289 currentEntry = entry.getValue();
290 break;
291 }
292 }
293 if (currentEntry) {
294 std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
295 // If the optional `PtrDLEntryPos::Index` entry is not available, use the
296 // pointer size as the index bitwidth.
297 if (!value && pos == PtrDLEntryPos::Index)
298 value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
299 bool isSizeOrIndex =
// Size/index values are returned as stored (bits). Alignment values are
// converted from bits to bytes. NOTE(review): `*value` is dereferenced without
// a check; this presumably relies on verifyEntries requiring 3 or 4 spec
// elements.
301 return *value / (isSizeOrIndex ? 1 : kBitsInByte);
302 }
303
304 // If not found, and this is the pointer to the default memory space, assume
305 // 64-bit pointers.
306 if (type.getAddressSpace() == 0) {
307 bool isSizeOrIndex =
// The defaults are already in the returned units: 64 bits for size/index and
// 8 bytes for alignments.
309 return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
310 }
311
312 return std::nullopt;
313}
314
// NOTE(review): in each of the four queries below, the line completing the
// `if (std::optional<uint64_t> ... =` initializer (original lines 319, 330,
// 340, 350) is missing from this copy. It is presumably a call to
// getPointerDataLayoutEntry with the matching PtrDLEntryPos. Each query uses
// the value found for this pointer's address space. Otherwise it falls back to
// the same query on `get(getContext())`, described by the comment below as the
// pointer to the default memory space.
315llvm::TypeSize
316LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
317 DataLayoutEntryListRef params) const {
318 if (std::optional<uint64_t> size =
320 return llvm::TypeSize::getFixed(*size);
321
322 // For other memory spaces, use the size of the pointer to the default memory
323 // space.
324 return dataLayout.getTypeSizeInBits(get(getContext()));
325}
326
327uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
328 DataLayoutEntryListRef params) const {
329 if (std::optional<uint64_t> alignment =
331 return *alignment;
332
333 return dataLayout.getTypeABIAlignment(get(getContext()));
334}
335
336uint64_t
337LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
338 DataLayoutEntryListRef params) const {
339 if (std::optional<uint64_t> alignment =
341 return *alignment;
342
343 return dataLayout.getTypePreferredAlignment(get(getContext()));
344}
345
346std::optional<uint64_t>
347LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
348 DataLayoutEntryListRef params) const {
349 if (std::optional<uint64_t> indexBitwidth =
351 return *indexBitwidth;
352
353 return dataLayout.getTypeIndexBitwidth(get(getContext()));
354}
355
// Checks whether `newLayout`'s pointer entries are compatible with
// `oldLayout`. The size must be unchanged, and the old ABI alignment must be a
// multiple of (and not smaller than) the new one.
// NOTE(review): original line 357 (the `oldLayout`/`newLayout` parameters) and
// lines 384-385 (the body of the `it != oldLayout.end()` branch, presumably
// reading `size`/`abi` from the old entry) are missing from this copy.
356bool LLVMPointerType::areCompatible(
358 DataLayoutSpecInterface newSpec,
359 const DataLayoutIdentifiedEntryMap &map) const {
360 for (DataLayoutEntryInterface newEntry : newLayout) {
361 if (!newEntry.isTypeEntry())
362 continue;
// NOTE(review): `abi` defaults to kDefaultPointerAlignment, which other code
// in this file returns as a byte alignment. `newAbi` below is the raw spec
// value, which getPointerDataLayoutEntry treats as bits. When no old entry is
// found, the comparison below therefore mixes units. Confirm intent.
363 uint64_t size = kDefaultPointerSizeBits;
364 uint64_t abi = kDefaultPointerAlignment;
365 auto newType =
366 llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey()));
367 const auto *it =
368 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
369 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
370 return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
371 newType.getAddressSpace();
372 }
373 return false;
374 });
// FIXME(review): the result of this fallback search for the address-space-0
// entry is discarded. `it` remains `oldLayout.end()`, so the fallback never
// takes effect. It presumably should be assigned back to `it`.
375 if (it == oldLayout.end()) {
376 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
377 if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
378 return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
379 }
380 return false;
381 });
382 }
383 if (it != oldLayout.end()) {
386 }
387
// NOTE(review): this local shadows the `newSpec` parameter of the function.
388 Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
389 uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
390 uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
// NOTE(review): nothing visible here rules out `newAbi == 0`, which would make
// the `%` below divide by zero.
391 if (size != newSize || abi < newAbi || abi % newAbi != 0)
392 return false;
393 }
394 return true;
395}
396
// Validates pointer layout entries. Each type-keyed entry must be a dense i64
// elements attribute with 3 or 4 values, and its preferred alignment must not
// be smaller than its ABI alignment.
// NOTE(review): original lines 413-414 (the opening of the ABI-vs-preferred
// comparison whose error is emitted below) are missing from this copy.
397LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
398 Location loc) const {
399 for (DataLayoutEntryInterface entry : entries) {
400 if (!entry.isTypeEntry())
401 continue;
402 auto key = llvm::cast<Type>(entry.getKey());
403 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
404 if (!values || (values.size() != 3 && values.size() != 4)) {
405 return emitError(loc)
406 << "expected layout attribute for " << key
407 << " to be a dense integer elements attribute with 3 or 4 "
408 "elements";
409 }
410 if (!values.getElementType().isInteger(64))
411 return emitError(loc) << "expected i64 parameters for " << key;
412
415 return emitError(loc) << "preferred alignment is expected to be at least "
416 "as large as ABI alignment";
417 }
418 }
419 return success();
420}
421
422//===----------------------------------------------------------------------===//
423// Struct type.
424//===----------------------------------------------------------------------===//
425
426bool LLVMStructType::isValidElementType(Type type) {
427 return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
428 LLVMFunctionType, LLVMTokenType>(type);
429}
430
431LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
432 StringRef name) {
433 return Base::get(context, name, /*opaque=*/false);
434}
435
436LLVMStructType LLVMStructType::getIdentifiedChecked(
437 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
438 StringRef name) {
439 return Base::getChecked(emitError, context, name, /*opaque=*/false);
440}
441
442LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
443 StringRef name,
444 ArrayRef<Type> elements,
445 bool isPacked) {
446 std::string stringName = name.str();
447 unsigned counter = 0;
448 do {
449 auto type = LLVMStructType::getIdentified(context, stringName);
450 if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
451 counter += 1;
452 stringName = (Twine(name) + "." + std::to_string(counter)).str();
453 continue;
454 }
455 return type;
456 } while (true);
457}
458
459LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
460 ArrayRef<Type> types, bool isPacked) {
461 return Base::get(context, types, isPacked);
462}
463
464LLVMStructType
465LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
466 MLIRContext *context, ArrayRef<Type> types,
467 bool isPacked) {
468 return Base::getChecked(emitError, context, types, isPacked);
469}
470
471LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
472 return Base::get(context, name, /*opaque=*/true);
473}
474
475LLVMStructType
476LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
477 MLIRContext *context, StringRef name) {
478 return Base::getChecked(emitError, context, name, /*opaque=*/true);
479}
480
481LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
482 assert(isIdentified() && "can only set bodies of identified structs");
483 assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
484 "expected valid body types");
485 return Base::mutate(types, isPacked);
486}
487
488bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
489bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
490bool LLVMStructType::isOpaque() const {
491 return getImpl()->isIdentified() &&
492 (getImpl()->isOpaque() || !getImpl()->isInitialized());
493}
494bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
495StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
496ArrayRef<Type> LLVMStructType::getBody() const {
497 return isIdentified() ? getImpl()->getIdentifiedStructBody()
498 : getImpl()->getTypeList();
499}
500
501LogicalResult
502LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
503 bool) {
504 return success();
505}
506
507LogicalResult
508LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
509 ArrayRef<Type> types, bool) {
510 for (Type t : types)
511 if (!isValidElementType(t))
512 return emitError() << "invalid LLVM structure element type: " << t;
513
514 return success();
515}
516
517llvm::TypeSize
518LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
519 DataLayoutEntryListRef params) const {
520 auto structSize = llvm::TypeSize::getFixed(0);
521 uint64_t structAlignment = 1;
522 for (Type element : getBody()) {
523 uint64_t elementAlignment =
524 isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
525 // Add padding to the struct size to align it to the abi alignment of the
526 // element type before than adding the size of the element.
527 structSize = llvm::alignTo(structSize, elementAlignment);
528 structSize += dataLayout.getTypeSize(element);
529
530 // The alignment requirement of a struct is equal to the strictest alignment
531 // requirement of its elements.
532 structAlignment = std::max(elementAlignment, structAlignment);
533 }
534 // At the end, add padding to the struct to satisfy its own alignment
535 // requirement. Otherwise structs inside of arrays would be misaligned.
536 structSize = llvm::alignTo(structSize, structAlignment);
537 return structSize * kBitsInByte;
538}
539
namespace {
/// Positions inside a struct data layout entry. The ABI alignment is always
/// present; the preferred alignment is optional.
enum class StructDLEntryPos {
  Abi = 0,
  Preferred = 1,
};
} // namespace
543
544static std::optional<uint64_t>
546 StructDLEntryPos pos) {
547 const auto *currentEntry =
548 llvm::find_if(params, [](DataLayoutEntryInterface entry) {
549 return entry.isTypeEntry();
550 });
551 if (currentEntry == params.end())
552 return std::nullopt;
553
554 auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
555 if (pos == StructDLEntryPos::Preferred &&
556 attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
557 // If no preferred was specified, fall back to abi alignment
558 pos = StructDLEntryPos::Abi;
559
560 return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
561}
562
563static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
565 LLVMStructType type,
566 StructDLEntryPos pos) {
567 // Packed structs always have an abi alignment of 1
568 if (pos == StructDLEntryPos::Abi && type.isPacked()) {
569 return 1;
570 }
571
572 // The alignment requirement of a struct is equal to the strictest alignment
573 // requirement of its elements.
574 uint64_t structAlignment = 1;
575 for (Type iter : type.getBody()) {
576 structAlignment =
577 std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
578 }
579
580 // Entries are only allowed to be stricter than the required alignment
581 if (std::optional<uint64_t> entryResult =
582 getStructDataLayoutEntry(params, type, pos))
583 return std::max(*entryResult / kBitsInByte, structAlignment);
584
585 return structAlignment;
586}
587
588uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
589 DataLayoutEntryListRef params) const {
590 return calculateStructAlignment(dataLayout, params, *this,
591 StructDLEntryPos::Abi);
592}
593
594uint64_t
595LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
596 DataLayoutEntryListRef params) const {
597 return calculateStructAlignment(dataLayout, params, *this,
598 StructDLEntryPos::Preferred);
599}
600
601static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
602 return llvm::cast<DenseIntElementsAttr>(attr)
603 .getValues<uint64_t>()[static_cast<size_t>(pos)];
604}
605
// Returns false if a new struct entry's ABI alignment is larger than the old
// one, or does not divide it evenly. New entries are accepted when the old
// layout has no type entry.
// NOTE(review): original line 607 (the `oldLayout`/`newLayout` parameters) is
// missing from this copy.
606bool LLVMStructType::areCompatible(
608 DataLayoutSpecInterface newSpec,
609 const DataLayoutIdentifiedEntryMap &map) const {
610 for (DataLayoutEntryInterface newEntry : newLayout) {
611 if (!newEntry.isTypeEntry())
612 continue;
613
// The lookup below does not depend on `newEntry`; it could be hoisted out of
// the loop.
614 const auto *previousEntry =
615 llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
616 return entry.isTypeEntry();
617 });
618 if (previousEntry == oldLayout.end())
619 continue;
620
621 uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
622 StructDLEntryPos::Abi);
623 uint64_t newAbi =
624 extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
// NOTE(review): LLVMStructType::verifyEntries does not reject a zero ABI
// value, so `newAbi` may be 0 here, and `abi % newAbi` would divide by zero.
625 if (abi < newAbi || abi % newAbi != 0)
626 return false;
627 }
628 return true;
629}
630
631LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
632 Location loc) const {
633 for (DataLayoutEntryInterface entry : entries) {
634 if (!entry.isTypeEntry())
635 continue;
636
637 auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey()));
638 auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
639 if (!values || (values.size() != 2 && values.size() != 1)) {
640 return emitError(loc)
641 << "expected layout attribute for "
642 << llvm::cast<Type>(entry.getKey())
643 << " to be a dense integer elements attribute of 1 or 2 elements";
644 }
645 if (!values.getElementType().isInteger(64))
646 return emitError(loc) << "expected i64 entries for " << key;
647
648 if (key.isIdentified() || !key.getBody().empty()) {
649 return emitError(loc) << "unexpected layout attribute for struct " << key;
650 }
651
652 if (values.size() == 1)
653 continue;
654
655 if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
656 extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
657 return emitError(loc) << "preferred alignment is expected to be at least "
658 "as large as ABI alignment";
659 }
660 }
661 return mlir::success();
662}
663
664//===----------------------------------------------------------------------===//
665// LLVMTargetExtType.
666//===----------------------------------------------------------------------===//
667
668static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
669static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
670static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier";
671
672bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
673 // See llvm/lib/IR/Type.cpp for reference.
674 uint64_t properties = 0;
675
676 if (getExtTypeName().starts_with(kSpirvPrefix))
677 properties |=
678 (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
679
680 if (getExtTypeName() == kAMDGCNNamedBarrier)
681 properties |= LLVMTargetExtType::CanBeGlobal;
682
683 return (properties & prop) == prop;
684}
685
686bool LLVM::LLVMTargetExtType::supportsMemOps() const {
687 // See llvm/lib/IR/Type.cpp for reference.
688 if (getExtTypeName().starts_with(kSpirvPrefix))
689 return true;
690
691 if (getExtTypeName() == kArmSVCount)
692 return true;
693
694 return false;
695}
696
697//===----------------------------------------------------------------------===//
698// LLVMPPCFP128Type
699//===----------------------------------------------------------------------===//
700
701const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
702 return APFloat::PPCDoubleDouble();
703}
704
705//===----------------------------------------------------------------------===//
706// Utility functions.
707//===----------------------------------------------------------------------===//
708
709/// Check whether type is a compatible ptr type. These are pointer-like types
710/// with no element type, no metadata, and using the LLVM
711/// LLVMAddrSpaceAttrInterface memory space.
712static bool isCompatiblePtrType(Type type) {
713 auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type);
714 if (!ptrTy)
715 return false;
716 return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr &&
717 isa<LLVMAddrSpaceAttrInterface>(ptrTy.getMemorySpace());
718}
719
// NOTE(review): original line 720 (this helper's signature) is missing from
// this copy. It presumably declares a `static bool ...(Type type)` function.
// Checks only the outermost layer of `type`:
// - arrays, structs, functions and target extension types are accepted
//   without looking at their contents;
// - integers must be signless;
// - vectors must be rank 1 (element types are not checked);
// - everything else defers to isCompatiblePtrType.
721 // clang-format off
722 if (llvm::isa<
723 BFloat16Type,
724 Float16Type,
725 Float32Type,
726 Float64Type,
727 Float80Type,
728 Float128Type,
729 LLVMArrayType,
730 LLVMFunctionType,
731 LLVMLabelType,
732 LLVMMetadataType,
733 LLVMPPCFP128Type,
734 LLVMPointerType,
735 LLVMStructType,
736 LLVMTokenType,
737 LLVMTargetExtType,
738 LLVMVoidType,
739 LLVMX86AMXType
740 >(type)) {
741 // clang-format on
742 return true;
743 }
744
745 // Only signless integers are compatible.
746 if (auto intType = llvm::dyn_cast<IntegerType>(type))
747 return intType.isSignless();
748
749 // 1D vector types are compatible.
750 if (auto vecType = llvm::dyn_cast<VectorType>(type))
751 return vecType.getRank() == 1;
752
753 return isCompatiblePtrType(type);
754}
755
// Recursively checks whether `type` and all nested types are LLVM-compatible.
// `compatibleTypes` serves as a memo of types already known (or assumed) to be
// compatible.
// NOTE(review): original line 765 (the start of the TypeSwitch expression) is
// missing from this copy.
756static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
// Inserting before recursing means a type that is reached again while it is
// still being checked (a cycle) is treated as compatible, so the recursion
// terminates.
757 if (!compatibleTypes.insert(type).second)
758 return true;
759
760 auto isCompatible = [&](Type type) {
761 return isCompatibleImpl(type, compatibleTypes);
762 };
763
764 bool result =
766 .Case<LLVMStructType>([&](auto structType) {
767 return llvm::all_of(structType.getBody(), isCompatible);
768 })
769 .Case<LLVMFunctionType>([&](auto funcType) {
770 return isCompatible(funcType.getReturnType()) &&
771 llvm::all_of(funcType.getParams(), isCompatible);
772 })
773 .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
774 .Case<VectorType>([&](auto vecType) {
775 return vecType.getRank() == 1 &&
776 isCompatible(vecType.getElementType());
777 })
778 .Case<LLVMPointerType>([&](auto pointerType) { return true; })
779 .Case<LLVMTargetExtType>([&](auto extType) {
780 return llvm::all_of(extType.getTypeParams(), isCompatible);
781 })
782 // clang-format off
783 .Case<
784 LLVMArrayType
785 >([&](auto containerType) {
786 return isCompatible(containerType.getElementType());
787 })
788 .Case<
789 BFloat16Type,
790 Float16Type,
791 Float32Type,
792 Float64Type,
793 Float80Type,
794 Float128Type,
795 LLVMLabelType,
796 LLVMMetadataType,
797 LLVMPPCFP128Type,
798 LLVMTokenType,
799 LLVMVoidType,
800 LLVMX86AMXType
801 >([](Type) { return true; })
802 // clang-format on
803 .Case<PtrLikeTypeInterface>(
804 [](Type type) { return isCompatiblePtrType(type); })
805 .Default(false);
806
// Undo this frame's provisional insertion on failure. NOTE(review): only this
// frame's type is erased. A nested type that succeeded only because it
// revisited a type still under evaluation stays cached as compatible, even if
// that enclosing type later fails.
807 if (!result)
808 compatibleTypes.erase(type);
809
810 return result;
811}
812
813bool LLVMDialect::isCompatibleType(Type type) {
814 if (auto *llvmDialect =
815 type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
816 return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
817
818 DenseSet<Type> localCompatibleTypes;
819 return isCompatibleImpl(type, localCompatibleTypes);
820}
821
// NOTE(review): original line 822 (this function's signature) is missing from
// this copy. The body simply forwards to LLVMDialect::isCompatibleType.
823 return LLVMDialect::isCompatibleType(type);
824}
825
// NOTE(review): original lines 826 (the signature) and 828 (the first operand
// of the LLVM_PrimitiveType condition) are missing from this copy.
// The visible conditions require that the type:
// - is not void or a function type;
// - is not an opaque struct;
// - is not a target extension type lacking memory-op support (supportsMemOps).
827 return /*LLVM_PrimitiveType*/ (
829 !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) &&
830 /*LLVM_OpaqueStruct*/
831 !(isa<LLVM::LLVMStructType>(type) &&
832 cast<LLVM::LLVMStructType>(type).isOpaque()) &&
833 /*LLVM_AnyTargetExt*/
834 !(isa<LLVM::LLVMTargetExtType>(type) &&
835 !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps());
836}
837
// NOTE(review): original line 838 (this function's signature) is missing from
// this copy. Returns true exactly for the scalar floating-point types listed
// below, including ppc_fp128.
839 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
840 Float80Type, Float128Type, LLVMPPCFP128Type>(type);
841}
842
// NOTE(review): original line 843 (this function's signature) is missing from
// this copy. Accepts only rank-1 VectorTypes whose element is one of:
// - a signless integer;
// - one of the float types listed below;
// - an LLVM pointer;
// - a pointer-like type accepted by isCompatiblePtrType.
// NOTE(review): LLVMPPCFP128Type is not in the element list, unlike the scalar
// floating-point list elsewhere in this file; confirm this is intended.
844 if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
845 if (vecType.getRank() != 1)
846 return false;
847 Type elementType = vecType.getElementType();
848 if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
849 return intType.isSignless();
850 return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
851 Float80Type, Float128Type, LLVMPointerType>(elementType) ||
852 isCompatiblePtrType(elementType);
853 }
854 return false;
855}
856
857llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
858 auto vecTy = dyn_cast<VectorType>(type);
859 assert(vecTy && "incompatible with LLVM vector type");
860 if (vecTy.isScalable())
861 return llvm::ElementCount::getScalable(vecTy.getNumElements());
862 return llvm::ElementCount::getFixed(vecTy.getNumElements());
863}
864
// NOTE(review): original line 865 (this function's signature) is missing from
// this copy. Asserts that `vectorType` is a VectorType and returns whether it
// is scalable.
866 assert(llvm::isa<VectorType>(vectorType) &&
867 "expected LLVM-compatible vector type");
868 return llvm::cast<VectorType>(vectorType).isScalable();
869}
870
871Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
872 bool isScalable) {
873 assert(VectorType::isValidElementType(elementType) &&
874 "incompatible element type");
875 return VectorType::get(numElements, elementType, {isScalable});
876}
877
// NOTE(review): original line 878 (the start of this overload's signature,
// which takes `elementType`) is missing from this copy. Forwards to the
// (elementType, count, isScalable) overload: the known-minimum count for
// scalable counts, the fixed value otherwise.
879 const llvm::ElementCount &numElements) {
880 if (numElements.isScalable())
881 return getVectorType(elementType, numElements.getKnownMinValue(),
882 /*isScalable=*/true);
883 return getVectorType(elementType, numElements.getFixedValue(),
884 /*isScalable=*/false);
885}
886
// NOTE(review): original lines 887 (the signature, presumably of
// getPrimitiveTypeSizeInBits given the recursive call below) and 891 (the
// start of the TypeSwitch expression) are missing from this copy.
// Returns the bit width of scalar float/integer types and of vectors of them.
// Non-primitive LLVM types (void, label, struct, array, pointer, ...) yield a
// fixed size of 0.
888 assert(isCompatibleType(type) &&
889 "expected a type compatible with the LLVM dialect");
890
892 .Case<BFloat16Type, Float16Type>(
893 [](Type) { return llvm::TypeSize::getFixed(16); })
894 .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
895 .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
896 .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
897 .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
898 .Case<IntegerType>([](IntegerType intTy) {
899 return llvm::TypeSize::getFixed(intTy.getWidth());
900 })
901 .Case<LLVMPPCFP128Type>(
902 [](Type) { return llvm::TypeSize::getFixed(128); })
903 .Case<VectorType>([](VectorType t) {
904 assert(isCompatibleVectorType(t) &&
905 "unexpected incompatible with LLVM vector type");
906 llvm::TypeSize elementSize =
907 getPrimitiveTypeSizeInBits(t.getElementType());
// FIXME(review): scalability is taken from the element size. Every scalar case
// above returns a fixed size, so this appears to report a scalable vector
// (e.g. vector<[4]xi32>) as a fixed size. `t.isScalable()` looks like the
// intended flag.
908 return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
909 elementSize.isScalable());
910 })
911 .Default([](Type ty) {
912 assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
913 LLVMTokenType, LLVMStructType, LLVMArrayType,
914 LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
915 ty)) &&
916 "unexpected missing support for primitive type");
917 return llvm::TypeSize::getFixed(0);
918 });
919}
920
921//===----------------------------------------------------------------------===//
922// LLVMDialect
923//===----------------------------------------------------------------------===//
924
925void LLVMDialect::registerTypes() {
926 addTypes<
927#define GET_TYPEDEF_LIST
928#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
929 >();
930}
931
932Type LLVMDialect::parseType(DialectAsmParser &parser) const {
933 return detail::parseType(parser);
934}
935
936void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
937 return detail::printType(type, os);
938}
return success()
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
constexpr static const uint64_t kDefaultPointerAlignment
static void printExtTypeParams(AsmPrinter &p, ArrayRef< Type > typeParams, ArrayRef< unsigned int > intParams)
constexpr static const uint64_t kDefaultPointerSizeBits
static void printFunctionTypes(AsmPrinter &p, ArrayRef< Type > params, bool isVarArg)
Definition LLVMTypes.cpp:66
static bool isCompatibleImpl(Type type, DenseSet< Type > &compatibleTypes)
static ParseResult parseExtTypeParams(AsmParser &p, SmallVectorImpl< Type > &typeParams, SmallVectorImpl< unsigned int > &intParams)
Parses the parameter list for a target extension type.
Definition LLVMTypes.cpp:90
static constexpr llvm::StringRef kSpirvPrefix
static std::optional< uint64_t > getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, PtrDLEntryPos pos)
Returns the part of the data layout entry that corresponds to pos for the given type by interpreting ...
static bool isCompatiblePtrType(Type type)
Check whether type is a compatible ptr type.
static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value)
These are unused for now.
static ParseResult parseFunctionTypes(AsmParser &p, SmallVector< Type > &params, bool &isVarArg)
Definition LLVMTypes.cpp:36
constexpr static const uint64_t kBitsInByte
Definition LLVMTypes.cpp:30
static constexpr llvm::StringRef kArmSVCount
static constexpr llvm::StringRef kAMDGCNNamedBarrier
static LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer)
static std::optional< uint64_t > getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos)
static uint64_t calculateStructAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
b getContext())
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalEllipsis()=0
Parse a `...` token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition Attributes.h:25
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This class represents a diagnostic that is inflight and set to be reported.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
llvm::TypeSize getPrimitiveTypeSizeInBits(Type type)
Returns the size of the given primitive LLVM dialect-compatible type (including vectors) in bits,...
void printPrettyLLVMType(AsmPrinter &p, Type type)
Print any MLIR type or a concise syntax for LLVM types.
bool isLoadableType(Type type)
Returns true if the given type is a loadable type compatible with the LLVM dialect.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
ParseResult parsePrettyLLVMType(AsmParser &p, Type &type)
Parse any MLIR type or a concise syntax for LLVM types.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
PtrDLEntryPos
The positions of different values in the data layout entry for pointers.
Definition LLVMTypes.h:146
std::optional< uint64_t > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
::llvm::MapVector<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface > DataLayoutIdentifiedEntryMap
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::ArrayRef< DataLayoutEntryInterface > DataLayoutEntryListRef
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152