MLIR  22.0.0git
LLVMTypes.cpp
Go to the documentation of this file.
1 //===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the types for the LLVM dialect in MLIR. These MLIR types
10 // correspond to the LLVM IR type system.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TypeDetail.h"
15 
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/TypeSupport.h"
22 
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/TypeSize.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::LLVM;
29 
/// Number of bits in a byte. Used to convert between the bit-based values of
/// data layout specs / size queries and the byte-based sizes and alignments.
static constexpr uint64_t kBitsInByte = 8;
31 
32 //===----------------------------------------------------------------------===//
33 // custom<FunctionTypes>
34 //===----------------------------------------------------------------------===//
35 
36 static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37  bool &isVarArg) {
38  isVarArg = false;
39  // `(` `)`
40  if (succeeded(p.parseOptionalRParen()))
41  return success();
42 
43  // `(` `...` `)`
44  if (succeeded(p.parseOptionalEllipsis())) {
45  isVarArg = true;
46  return p.parseRParen();
47  }
48 
49  // type (`,` type)* (`,` `...`)?
50  Type type;
51  if (parsePrettyLLVMType(p, type))
52  return failure();
53  params.push_back(type);
54  while (succeeded(p.parseOptionalComma())) {
55  if (succeeded(p.parseOptionalEllipsis())) {
56  isVarArg = true;
57  return p.parseRParen();
58  }
59  if (parsePrettyLLVMType(p, type))
60  return failure();
61  params.push_back(type);
62  }
63  return p.parseRParen();
64 }
65 
67  bool isVarArg) {
68  llvm::interleaveComma(params, p,
69  [&](Type type) { printPrettyLLVMType(p, type); });
70  if (isVarArg) {
71  if (!params.empty())
72  p << ", ";
73  p << "...";
74  }
75  p << ')';
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // custom<ExtTypeParams>
80 //===----------------------------------------------------------------------===//
81 
82 /// Parses the parameter list for a target extension type. The parameter list
83 /// contains an optional list of type parameters, followed by an optional list
84 /// of integer parameters. Type and integer parameters cannot be interleaved in
85 /// the list.
86 /// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87 /// typeList ::= type ("," type)*
88 /// intList ::= integer ("," integer)*
89 static ParseResult
91  SmallVectorImpl<unsigned int> &intParams) {
92  bool parseType = true;
93  auto typeOrIntParser = [&]() -> ParseResult {
94  unsigned int i;
95  auto intResult = p.parseOptionalInteger(i);
96  if (intResult.has_value() && !failed(*intResult)) {
97  // Successfully parsed an integer.
98  intParams.push_back(i);
99  // After the first integer was successfully parsed, no
100  // more types can be parsed.
101  parseType = false;
102  return success();
103  }
104  if (parseType) {
105  Type t;
106  if (!parsePrettyLLVMType(p, t)) {
107  // Successfully parsed a type.
108  typeParams.push_back(t);
109  return success();
110  }
111  }
112  return failure();
113  };
114  if (p.parseCommaSeparatedList(typeOrIntParser)) {
116  "failed to parse parameter list for target extension type");
117  return failure();
118  }
119  return success();
120 }
121 
122 static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
123  ArrayRef<unsigned int> intParams) {
124  p << typeParams;
125  if (!typeParams.empty() && !intParams.empty())
126  p << ", ";
127 
128  p << intParams;
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // ODS-Generated Definitions
133 //===----------------------------------------------------------------------===//
134 
135 /// These are unused for now.
136 /// TODO: Move over to these once more types have been migrated to TypeDef.
137 LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
138 generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
139 LLVM_ATTRIBUTE_UNUSED static LogicalResult
141 
142 #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143 
144 #define GET_TYPEDEF_CLASSES
145 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146 
147 //===----------------------------------------------------------------------===//
148 // LLVMArrayType
149 //===----------------------------------------------------------------------===//
150 
151 bool LLVMArrayType::isValidElementType(Type type) {
152  return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153  LLVMFunctionType, LLVMTokenType>(type);
154 }
155 
156 LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
157  assert(elementType && "expected non-null subtype");
158  return Base::get(elementType.getContext(), elementType, numElements);
159 }
160 
161 LLVMArrayType
162 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
163  Type elementType, uint64_t numElements) {
164  assert(elementType && "expected non-null subtype");
165  return Base::getChecked(emitError, elementType.getContext(), elementType,
166  numElements);
167 }
168 
169 LogicalResult
171  Type elementType, uint64_t numElements) {
172  if (!isValidElementType(elementType))
173  return emitError() << "invalid array element type: " << elementType;
174  return success();
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // DataLayoutTypeInterface
179 //===----------------------------------------------------------------------===//
180 
181 llvm::TypeSize
182 LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183  DataLayoutEntryListRef params) const {
184  return llvm::TypeSize::getFixed(kBitsInByte *
185  getTypeSize(dataLayout, params));
186 }
187 
188 llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189  DataLayoutEntryListRef params) const {
190  return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191  dataLayout.getTypeABIAlignment(getElementType())) *
192  getNumElements();
193 }
194 
195 uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196  DataLayoutEntryListRef params) const {
197  return dataLayout.getTypeABIAlignment(getElementType());
198 }
199 
200 uint64_t
201 LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202  DataLayoutEntryListRef params) const {
203  return dataLayout.getTypePreferredAlignment(getElementType());
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // Function type.
208 //===----------------------------------------------------------------------===//
209 
210 bool LLVMFunctionType::isValidArgumentType(Type type) {
211  return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
212 }
213 
214 bool LLVMFunctionType::isValidResultType(Type type) {
215  return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
216 }
217 
218 LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
219  bool isVarArg) {
220  assert(result && "expected non-null result");
221  return Base::get(result.getContext(), result, arguments, isVarArg);
222 }
223 
224 LLVMFunctionType
225 LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226  Type result, ArrayRef<Type> arguments,
227  bool isVarArg) {
228  assert(result && "expected non-null result");
229  return Base::getChecked(emitError, result.getContext(), result, arguments,
230  isVarArg);
231 }
232 
233 LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234  TypeRange results) const {
235  if (results.size() != 1 || !isValidResultType(results[0]))
236  return {};
237  if (!llvm::all_of(inputs, isValidArgumentType))
238  return {};
239  return get(results[0], llvm::to_vector(inputs), isVarArg());
240 }
241 
243  return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
244 }
245 
246 LogicalResult
248  Type result, ArrayRef<Type> arguments, bool) {
249  if (!isValidResultType(result))
250  return emitError() << "invalid function result type: " << result;
251 
252  for (Type arg : arguments)
253  if (!isValidArgumentType(arg))
254  return emitError() << "invalid function argument type: " << arg;
255 
256  return success();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // DataLayoutTypeInterface
261 //===----------------------------------------------------------------------===//
262 
/// Pointer size, in bits, assumed for address space 0 when the data layout
/// has no entry for it.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
/// Pointer alignment, in bytes, assumed for address space 0 when the data
/// layout has no entry for it.
static constexpr uint64_t kDefaultPointerAlignment = 8;
265 
266 std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr,
267  PtrDLEntryPos pos) {
268  auto spec = cast<DenseIntElementsAttr>(attr);
269  auto idx = static_cast<int64_t>(pos);
270  if (idx >= spec.size())
271  return std::nullopt;
272  return spec.getValues<uint64_t>()[idx];
273 }
274 
275 /// Returns the part of the data layout entry that corresponds to `pos` for the
276 /// given `type` by interpreting the list of entries `params`. For the pointer
277 /// type in the default address space, returns the default value if the entries
278 /// do not provide a custom one, for other address spaces returns std::nullopt.
279 static std::optional<uint64_t>
281  PtrDLEntryPos pos) {
282  // First, look for the entry for the pointer in the current address space.
283  Attribute currentEntry;
284  for (DataLayoutEntryInterface entry : params) {
285  if (!entry.isTypeEntry())
286  continue;
287  if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() ==
288  type.getAddressSpace()) {
289  currentEntry = entry.getValue();
290  break;
291  }
292  }
293  if (currentEntry) {
294  std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
295  // If the optional `PtrDLEntryPos::Index` entry is not available, use the
296  // pointer size as the index bitwidth.
297  if (!value && pos == PtrDLEntryPos::Index)
298  value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
299  bool isSizeOrIndex =
300  pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
301  return *value / (isSizeOrIndex ? 1 : kBitsInByte);
302  }
303 
304  // If not found, and this is the pointer to the default memory space, assume
305  // 64-bit pointers.
306  if (type.getAddressSpace() == 0) {
307  bool isSizeOrIndex =
308  pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
309  return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
310  }
311 
312  return std::nullopt;
313 }
314 
315 llvm::TypeSize
316 LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
317  DataLayoutEntryListRef params) const {
318  if (std::optional<uint64_t> size =
320  return llvm::TypeSize::getFixed(*size);
321 
322  // For other memory spaces, use the size of the pointer to the default memory
323  // space.
324  return dataLayout.getTypeSizeInBits(get(getContext()));
325 }
326 
327 uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
328  DataLayoutEntryListRef params) const {
329  if (std::optional<uint64_t> alignment =
331  return *alignment;
332 
333  return dataLayout.getTypeABIAlignment(get(getContext()));
334 }
335 
336 uint64_t
337 LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
338  DataLayoutEntryListRef params) const {
339  if (std::optional<uint64_t> alignment =
341  return *alignment;
342 
343  return dataLayout.getTypePreferredAlignment(get(getContext()));
344 }
345 
346 std::optional<uint64_t>
348  DataLayoutEntryListRef params) const {
349  if (std::optional<uint64_t> indexBitwidth =
351  return *indexBitwidth;
352 
353  return dataLayout.getTypeIndexBitwidth(get(getContext()));
354 }
355 
356 bool LLVMPointerType::areCompatible(
357  DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
358  DataLayoutSpecInterface newSpec,
359  const DataLayoutIdentifiedEntryMap &map) const {
360  for (DataLayoutEntryInterface newEntry : newLayout) {
361  if (!newEntry.isTypeEntry())
362  continue;
363  uint64_t size = kDefaultPointerSizeBits;
364  uint64_t abi = kDefaultPointerAlignment;
365  auto newType =
366  llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey()));
367  const auto *it =
368  llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
369  if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
370  return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
371  newType.getAddressSpace();
372  }
373  return false;
374  });
375  if (it == oldLayout.end()) {
376  llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
377  if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
378  return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
379  }
380  return false;
381  });
382  }
383  if (it != oldLayout.end()) {
386  }
387 
388  Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
389  uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
390  uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
391  if (size != newSize || abi < newAbi || abi % newAbi != 0)
392  return false;
393  }
394  return true;
395 }
396 
398  Location loc) const {
399  for (DataLayoutEntryInterface entry : entries) {
400  if (!entry.isTypeEntry())
401  continue;
402  auto key = llvm::cast<Type>(entry.getKey());
403  auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
404  if (!values || (values.size() != 3 && values.size() != 4)) {
405  return emitError(loc)
406  << "expected layout attribute for " << key
407  << " to be a dense integer elements attribute with 3 or 4 "
408  "elements";
409  }
410  if (!values.getElementType().isInteger(64))
411  return emitError(loc) << "expected i64 parameters for " << key;
412 
415  return emitError(loc) << "preferred alignment is expected to be at least "
416  "as large as ABI alignment";
417  }
418  }
419  return success();
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // Struct type.
424 //===----------------------------------------------------------------------===//
425 
426 bool LLVMStructType::isValidElementType(Type type) {
427  return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
428  LLVMFunctionType, LLVMTokenType>(type);
429 }
430 
431 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
432  StringRef name) {
433  return Base::get(context, name, /*opaque=*/false);
434 }
435 
436 LLVMStructType LLVMStructType::getIdentifiedChecked(
438  StringRef name) {
439  return Base::getChecked(emitError, context, name, /*opaque=*/false);
440 }
441 
442 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
443  StringRef name,
444  ArrayRef<Type> elements,
445  bool isPacked) {
446  std::string stringName = name.str();
447  unsigned counter = 0;
448  do {
449  auto type = LLVMStructType::getIdentified(context, stringName);
450  if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
451  counter += 1;
452  stringName = (Twine(name) + "." + std::to_string(counter)).str();
453  continue;
454  }
455  return type;
456  } while (true);
457 }
458 
459 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
460  ArrayRef<Type> types, bool isPacked) {
461  return Base::get(context, types, isPacked);
462 }
463 
464 LLVMStructType
465 LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
466  MLIRContext *context, ArrayRef<Type> types,
467  bool isPacked) {
468  return Base::getChecked(emitError, context, types, isPacked);
469 }
470 
471 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
472  return Base::get(context, name, /*opaque=*/true);
473 }
474 
475 LLVMStructType
476 LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
477  MLIRContext *context, StringRef name) {
478  return Base::getChecked(emitError, context, name, /*opaque=*/true);
479 }
480 
481 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
482  assert(isIdentified() && "can only set bodies of identified structs");
483  assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
484  "expected valid body types");
485  return Base::mutate(types, isPacked);
486 }
487 
488 bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
489 bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
490 bool LLVMStructType::isOpaque() const {
491  return getImpl()->isIdentified() &&
492  (getImpl()->isOpaque() || !getImpl()->isInitialized());
493 }
494 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
495 StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
496 ArrayRef<Type> LLVMStructType::getBody() const {
497  return isIdentified() ? getImpl()->getIdentifiedStructBody()
498  : getImpl()->getTypeList();
499 }
500 
501 LogicalResult
502 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
503  bool) {
504  return success();
505 }
506 
507 LogicalResult
508 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
509  ArrayRef<Type> types, bool) {
510  for (Type t : types)
511  if (!isValidElementType(t))
512  return emitError() << "invalid LLVM structure element type: " << t;
513 
514  return success();
515 }
516 
517 llvm::TypeSize
518 LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
519  DataLayoutEntryListRef params) const {
520  auto structSize = llvm::TypeSize::getFixed(0);
521  uint64_t structAlignment = 1;
522  for (Type element : getBody()) {
523  uint64_t elementAlignment =
524  isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
525  // Add padding to the struct size to align it to the abi alignment of the
526  // element type before than adding the size of the element.
527  structSize = llvm::alignTo(structSize, elementAlignment);
528  structSize += dataLayout.getTypeSize(element);
529 
530  // The alignment requirement of a struct is equal to the strictest alignment
531  // requirement of its elements.
532  structAlignment = std::max(elementAlignment, structAlignment);
533  }
534  // At the end, add padding to the struct to satisfy its own alignment
535  // requirement. Otherwise structs inside of arrays would be misaligned.
536  structSize = llvm::alignTo(structSize, structAlignment);
537  return structSize * kBitsInByte;
538 }
539 
540 namespace {
541 enum class StructDLEntryPos { Abi = 0, Preferred = 1 };
542 } // namespace
543 
544 static std::optional<uint64_t>
546  StructDLEntryPos pos) {
547  const auto *currentEntry =
548  llvm::find_if(params, [](DataLayoutEntryInterface entry) {
549  return entry.isTypeEntry();
550  });
551  if (currentEntry == params.end())
552  return std::nullopt;
553 
554  auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
555  if (pos == StructDLEntryPos::Preferred &&
556  attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
557  // If no preferred was specified, fall back to abi alignment
558  pos = StructDLEntryPos::Abi;
559 
560  return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
561 }
562 
563 static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
564  DataLayoutEntryListRef params,
565  LLVMStructType type,
566  StructDLEntryPos pos) {
567  // Packed structs always have an abi alignment of 1
568  if (pos == StructDLEntryPos::Abi && type.isPacked()) {
569  return 1;
570  }
571 
572  // The alignment requirement of a struct is equal to the strictest alignment
573  // requirement of its elements.
574  uint64_t structAlignment = 1;
575  for (Type iter : type.getBody()) {
576  structAlignment =
577  std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
578  }
579 
580  // Entries are only allowed to be stricter than the required alignment
581  if (std::optional<uint64_t> entryResult =
582  getStructDataLayoutEntry(params, type, pos))
583  return std::max(*entryResult / kBitsInByte, structAlignment);
584 
585  return structAlignment;
586 }
587 
588 uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
589  DataLayoutEntryListRef params) const {
590  return calculateStructAlignment(dataLayout, params, *this,
591  StructDLEntryPos::Abi);
592 }
593 
594 uint64_t
595 LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
596  DataLayoutEntryListRef params) const {
597  return calculateStructAlignment(dataLayout, params, *this,
598  StructDLEntryPos::Preferred);
599 }
600 
601 static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
602  return llvm::cast<DenseIntElementsAttr>(attr)
603  .getValues<uint64_t>()[static_cast<size_t>(pos)];
604 }
605 
606 bool LLVMStructType::areCompatible(
607  DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
608  DataLayoutSpecInterface newSpec,
609  const DataLayoutIdentifiedEntryMap &map) const {
610  for (DataLayoutEntryInterface newEntry : newLayout) {
611  if (!newEntry.isTypeEntry())
612  continue;
613 
614  const auto *previousEntry =
615  llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
616  return entry.isTypeEntry();
617  });
618  if (previousEntry == oldLayout.end())
619  continue;
620 
621  uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
622  StructDLEntryPos::Abi);
623  uint64_t newAbi =
624  extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
625  if (abi < newAbi || abi % newAbi != 0)
626  return false;
627  }
628  return true;
629 }
630 
632  Location loc) const {
633  for (DataLayoutEntryInterface entry : entries) {
634  if (!entry.isTypeEntry())
635  continue;
636 
637  auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey()));
638  auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
639  if (!values || (values.size() != 2 && values.size() != 1)) {
640  return emitError(loc)
641  << "expected layout attribute for "
642  << llvm::cast<Type>(entry.getKey())
643  << " to be a dense integer elements attribute of 1 or 2 elements";
644  }
645  if (!values.getElementType().isInteger(64))
646  return emitError(loc) << "expected i64 entries for " << key;
647 
648  if (key.isIdentified() || !key.getBody().empty()) {
649  return emitError(loc) << "unexpected layout attribute for struct " << key;
650  }
651 
652  if (values.size() == 1)
653  continue;
654 
655  if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
656  extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
657  return emitError(loc) << "preferred alignment is expected to be at least "
658  "as large as ABI alignment";
659  }
660  }
661  return mlir::success();
662 }
663 
664 //===----------------------------------------------------------------------===//
665 // LLVMTargetExtType.
666 //===----------------------------------------------------------------------===//
667 
668 static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
669 static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
670 
671 bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
672  // See llvm/lib/IR/Type.cpp for reference.
673  uint64_t properties = 0;
674 
675  if (getExtTypeName().starts_with(kSpirvPrefix))
676  properties |=
677  (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
678 
679  return (properties & prop) == prop;
680 }
681 
682 bool LLVM::LLVMTargetExtType::supportsMemOps() const {
683  // See llvm/lib/IR/Type.cpp for reference.
684  if (getExtTypeName().starts_with(kSpirvPrefix))
685  return true;
686 
687  if (getExtTypeName() == kArmSVCount)
688  return true;
689 
690  return false;
691 }
692 
693 //===----------------------------------------------------------------------===//
694 // LLVMPPCFP128Type
695 //===----------------------------------------------------------------------===//
696 
697 const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
698  return APFloat::PPCDoubleDouble();
699 }
700 
701 //===----------------------------------------------------------------------===//
702 // Utility functions.
703 //===----------------------------------------------------------------------===//
704 
705 /// Check whether type is a compatible ptr type. These are pointer-like types
706 /// with no element type, no metadata, and using the LLVM AddressSpaceAttr
707 /// memory space.
708 static bool isCompatiblePtrType(Type type) {
709  auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type);
710  if (!ptrTy)
711  return false;
712  return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr &&
713  isa<AddressSpaceAttr>(ptrTy.getMemorySpace());
714 }
715 
717  // clang-format off
718  if (llvm::isa<
719  BFloat16Type,
720  Float16Type,
721  Float32Type,
722  Float64Type,
723  Float80Type,
724  Float128Type,
725  LLVMArrayType,
726  LLVMFunctionType,
727  LLVMLabelType,
728  LLVMMetadataType,
729  LLVMPPCFP128Type,
730  LLVMPointerType,
731  LLVMStructType,
732  LLVMTokenType,
733  LLVMTargetExtType,
734  LLVMVoidType,
735  LLVMX86AMXType
736  >(type)) {
737  // clang-format on
738  return true;
739  }
740 
741  // Only signless integers are compatible.
742  if (auto intType = llvm::dyn_cast<IntegerType>(type))
743  return intType.isSignless();
744 
745  // 1D vector types are compatible.
746  if (auto vecType = llvm::dyn_cast<VectorType>(type))
747  return vecType.getRank() == 1;
748 
749  return isCompatiblePtrType(type);
750 }
751 
752 static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
753  if (!compatibleTypes.insert(type).second)
754  return true;
755 
756  auto isCompatible = [&](Type type) {
757  return isCompatibleImpl(type, compatibleTypes);
758  };
759 
760  bool result =
762  .Case<LLVMStructType>([&](auto structType) {
763  return llvm::all_of(structType.getBody(), isCompatible);
764  })
765  .Case<LLVMFunctionType>([&](auto funcType) {
766  return isCompatible(funcType.getReturnType()) &&
767  llvm::all_of(funcType.getParams(), isCompatible);
768  })
769  .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
770  .Case<VectorType>([&](auto vecType) {
771  return vecType.getRank() == 1 &&
772  isCompatible(vecType.getElementType());
773  })
774  .Case<LLVMPointerType>([&](auto pointerType) { return true; })
775  .Case<LLVMTargetExtType>([&](auto extType) {
776  return llvm::all_of(extType.getTypeParams(), isCompatible);
777  })
778  // clang-format off
779  .Case<
780  LLVMArrayType
781  >([&](auto containerType) {
782  return isCompatible(containerType.getElementType());
783  })
784  .Case<
785  BFloat16Type,
786  Float16Type,
787  Float32Type,
788  Float64Type,
789  Float80Type,
790  Float128Type,
791  LLVMLabelType,
792  LLVMMetadataType,
793  LLVMPPCFP128Type,
794  LLVMTokenType,
795  LLVMVoidType,
796  LLVMX86AMXType
797  >([](Type) { return true; })
798  // clang-format on
799  .Case<PtrLikeTypeInterface>(
800  [](Type type) { return isCompatiblePtrType(type); })
801  .Default([](Type) { return false; });
802 
803  if (!result)
804  compatibleTypes.erase(type);
805 
806  return result;
807 }
808 
810  if (auto *llvmDialect =
811  type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
812  return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
813 
814  DenseSet<Type> localCompatibleTypes;
815  return isCompatibleImpl(type, localCompatibleTypes);
816 }
817 
819  return LLVMDialect::isCompatibleType(type);
820 }
821 
823  return /*LLVM_PrimitiveType*/ (
825  !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) &&
826  /*LLVM_OpaqueStruct*/
827  !(isa<LLVM::LLVMStructType>(type) &&
828  cast<LLVM::LLVMStructType>(type).isOpaque()) &&
829  /*LLVM_AnyTargetExt*/
830  !(isa<LLVM::LLVMTargetExtType>(type) &&
831  !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps());
832 }
833 
835  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
836  Float80Type, Float128Type, LLVMPPCFP128Type>(type);
837 }
838 
840  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
841  if (vecType.getRank() != 1)
842  return false;
843  Type elementType = vecType.getElementType();
844  if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
845  return intType.isSignless();
846  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
847  Float80Type, Float128Type, LLVMPointerType>(elementType) ||
848  isCompatiblePtrType(elementType);
849  }
850  return false;
851 }
852 
853 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
854  auto vecTy = dyn_cast<VectorType>(type);
855  assert(vecTy && "incompatible with LLVM vector type");
856  if (vecTy.isScalable())
857  return llvm::ElementCount::getScalable(vecTy.getNumElements());
858  return llvm::ElementCount::getFixed(vecTy.getNumElements());
859 }
860 
862  assert(llvm::isa<VectorType>(vectorType) &&
863  "expected LLVM-compatible vector type");
864  return llvm::cast<VectorType>(vectorType).isScalable();
865 }
866 
867 Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
868  bool isScalable) {
869  assert(VectorType::isValidElementType(elementType) &&
870  "incompatible element type");
871  return VectorType::get(numElements, elementType, {isScalable});
872 }
873 
875  const llvm::ElementCount &numElements) {
876  if (numElements.isScalable())
877  return getVectorType(elementType, numElements.getKnownMinValue(),
878  /*isScalable=*/true);
879  return getVectorType(elementType, numElements.getFixedValue(),
880  /*isScalable=*/false);
881 }
882 
884  assert(isCompatibleType(type) &&
885  "expected a type compatible with the LLVM dialect");
886 
888  .Case<BFloat16Type, Float16Type>(
889  [](Type) { return llvm::TypeSize::getFixed(16); })
890  .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
891  .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
892  .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
893  .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
894  .Case<IntegerType>([](IntegerType intTy) {
895  return llvm::TypeSize::getFixed(intTy.getWidth());
896  })
897  .Case<LLVMPPCFP128Type>(
898  [](Type) { return llvm::TypeSize::getFixed(128); })
899  .Case<VectorType>([](VectorType t) {
900  assert(isCompatibleVectorType(t) &&
901  "unexpected incompatible with LLVM vector type");
902  llvm::TypeSize elementSize =
903  getPrimitiveTypeSizeInBits(t.getElementType());
904  return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
905  elementSize.isScalable());
906  })
907  .Default([](Type ty) {
908  assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
909  LLVMTokenType, LLVMStructType, LLVMArrayType,
910  LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
911  ty)) &&
912  "unexpected missing support for primitive type");
913  return llvm::TypeSize::getFixed(0);
914  });
915 }
916 
917 //===----------------------------------------------------------------------===//
918 // LLVMDialect
919 //===----------------------------------------------------------------------===//
920 
921 void LLVMDialect::registerTypes() {
922  addTypes<
923 #define GET_TYPEDEF_LIST
924 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
925  >();
926 }
927 
929  return detail::parseType(parser);
930 }
931 
932 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
933  return detail::printType(type, os);
934 }
static LogicalResult verifyEntries(function_ref< InFlightDiagnostic()> emitError, ArrayRef< DataLayoutEntryInterface > entries, bool allowTypes=true)
Verify entries, with the option to disallow types as keys.
Definition: DLTI.cpp:134
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
constexpr static const uint64_t kDefaultPointerAlignment
Definition: LLVMTypes.cpp:264
static void printExtTypeParams(AsmPrinter &p, ArrayRef< Type > typeParams, ArrayRef< unsigned int > intParams)
Definition: LLVMTypes.cpp:122
constexpr static const uint64_t kDefaultPointerSizeBits
Definition: LLVMTypes.cpp:263
static LLVM_ATTRIBUTE_UNUSED OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value)
These are unused for now.
static void printFunctionTypes(AsmPrinter &p, ArrayRef< Type > params, bool isVarArg)
Definition: LLVMTypes.cpp:66
static bool isCompatibleImpl(Type type, DenseSet< Type > &compatibleTypes)
Definition: LLVMTypes.cpp:752
static ParseResult parseExtTypeParams(AsmParser &p, SmallVectorImpl< Type > &typeParams, SmallVectorImpl< unsigned int > &intParams)
Parses the parameter list for a target extension type.
Definition: LLVMTypes.cpp:90
static LLVM_ATTRIBUTE_UNUSED LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer)
static constexpr llvm::StringRef kSpirvPrefix
Definition: LLVMTypes.cpp:668
static bool isCompatiblePtrType(Type type)
Check whether type is a compatible ptr type.
Definition: LLVMTypes.cpp:708
static ParseResult parseFunctionTypes(AsmParser &p, SmallVector< Type > &params, bool &isVarArg)
Definition: LLVMTypes.cpp:36
constexpr static const uint64_t kBitsInByte
Definition: LLVMTypes.cpp:30
static constexpr llvm::StringRef kArmSVCount
Definition: LLVMTypes.cpp:669
static std::optional< uint64_t > getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, PtrDLEntryPos pos)
Returns the part of the data layout entry that corresponds to pos for the given type by interpreting ...
Definition: LLVMTypes.cpp:280
static std::optional< uint64_t > getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:545
static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:601
static uint64_t calculateStructAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:563
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalEllipsis()=0
Parse a `...` token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:867
llvm::TypeSize getPrimitiveTypeSizeInBits(Type type)
Returns the size of the given primitive LLVM dialect-compatible type (including vectors) in bits,...
Definition: LLVMTypes.cpp:883
void printPrettyLLVMType(AsmPrinter &p, Type type)
Print any MLIR type or a concise syntax for LLVM types.
bool isLoadableType(Type type)
Returns true if the given type is a loadable type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:822
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
Definition: LLVMTypes.cpp:861
ParseResult parsePrettyLLVMType(AsmParser &p, Type &type)
Parse any MLIR type or a concise syntax for LLVM types.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:839
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
Definition: LLVMTypes.cpp:716
PtrDLEntryPos
The positions of different values in the data layout entry for pointers.
Definition: LLVMTypes.h:146
std::optional< uint64_t > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
Definition: LLVMTypes.cpp:266
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:809
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:834
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:853
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
::llvm::MapVector<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface > DataLayoutIdentifiedEntryMap
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423